Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1d0c9d6b
Unverified
Commit
1d0c9d6b
authored
May 06, 2025
by
Jinzhen Lin
Committed by
GitHub
May 05, 2025
Browse files
[Kernel] some optimizations for dense marlin and moe marlin (#16850)
Signed-off-by:
Jinzhen Lin
<
linjinzhen@hotmail.com
>
parent
f62cad64
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3123 additions
and
3145 deletions
+3123
-3145
CMakeLists.txt
CMakeLists.txt
+46
-2
csrc/moe/marlin_moe_wna16/.gitignore
csrc/moe/marlin_moe_wna16/.gitignore
+1
-0
csrc/moe/marlin_moe_wna16/generate_kernels.py
csrc/moe/marlin_moe_wna16/generate_kernels.py
+12
-8
csrc/moe/marlin_moe_wna16/kernel.h
csrc/moe/marlin_moe_wna16/kernel.h
+4
-6
csrc/moe/marlin_moe_wna16/marlin_template.h
csrc/moe/marlin_moe_wna16/marlin_template.h
+208
-260
csrc/moe/marlin_moe_wna16/ops.cu
csrc/moe/marlin_moe_wna16/ops.cu
+155
-190
csrc/quantization/gptq_marlin/.gitignore
csrc/quantization/gptq_marlin/.gitignore
+1
-0
csrc/quantization/gptq_marlin/dequant.h
csrc/quantization/gptq_marlin/dequant.h
+291
-0
csrc/quantization/gptq_marlin/generate_kernels.py
csrc/quantization/gptq_marlin/generate_kernels.py
+116
-0
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+442
-2063
csrc/quantization/gptq_marlin/kernel.h
csrc/quantization/gptq_marlin/kernel.h
+37
-0
csrc/quantization/gptq_marlin/marlin_template.h
csrc/quantization/gptq_marlin/marlin_template.h
+1678
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+4
-13
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+44
-121
tests/kernels/quantization/test_awq_marlin.py
tests/kernels/quantization/test_awq_marlin.py
+0
-164
tests/kernels/quantization/test_marlin_gemm.py
tests/kernels/quantization/test_marlin_gemm.py
+38
-103
vllm/_custom_ops.py
vllm/_custom_ops.py
+14
-31
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+19
-170
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+11
-13
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
...ompressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
+2
-1
No files found.
CMakeLists.txt
View file @
1d0c9d6b
...
...
@@ -301,8 +301,52 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# are not supported by Machete yet.
cuda_archs_loose_intersection
(
MARLIN_ARCHS
"8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0"
"
${
CUDA_ARCHS
}
"
)
if
(
MARLIN_ARCHS
)
#
# For the Marlin kernels we automatically generate sources for various
# preselected input type pairs and schedules.
# Generate sources:
set
(
MARLIN_GEN_SCRIPT
${
CMAKE_CURRENT_SOURCE_DIR
}
/csrc/quantization/gptq_marlin/generate_kernels.py
)
file
(
MD5
${
MARLIN_GEN_SCRIPT
}
MARLIN_GEN_SCRIPT_HASH
)
message
(
STATUS
"Marlin generation script hash:
${
MARLIN_GEN_SCRIPT_HASH
}
"
)
message
(
STATUS
"Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}"
)
if
(
NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH}
OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH} STREQUAL
${
MARLIN_GEN_SCRIPT_HASH
}
)
execute_process
(
COMMAND
${
CMAKE_COMMAND
}
-E env
PYTHONPATH=$PYTHONPATH
${
Python_EXECUTABLE
}
${
MARLIN_GEN_SCRIPT
}
RESULT_VARIABLE marlin_generation_result
OUTPUT_VARIABLE marlin_generation_result
OUTPUT_FILE
${
CMAKE_CURRENT_BINARY_DIR
}
/marlin_generation.log
ERROR_FILE
${
CMAKE_CURRENT_BINARY_DIR
}
/marlin_generation.log
)
if
(
NOT marlin_generation_result EQUAL 0
)
message
(
FATAL_ERROR
"Marlin generation failed."
" Result:
\"
${
marlin_generation_result
}
\"
"
"
\n
Check the log for details: "
"
${
CMAKE_CURRENT_BINARY_DIR
}
/marlin_generation.log"
)
else
()
set
(
MARLIN_GEN_SCRIPT_HASH
${
MARLIN_GEN_SCRIPT_HASH
}
CACHE STRING
"Last run Marlin generate script hash"
FORCE
)
message
(
STATUS
"Marlin generation completed successfully."
)
endif
()
else
()
message
(
STATUS
"Marlin generation script has not changed, skipping generation."
)
endif
()
file
(
GLOB MARLIN_TEMPLATE_KERNEL_SRC
"csrc/quantization/gptq_marlin/kernel_*.cu"
)
set_gencode_flags_for_srcs
(
SRCS
"
${
MARLIN_TEMPLATE_KERNEL_SRC
}
"
CUDA_ARCHS
"
${
MARLIN_ARCHS
}
"
)
list
(
APPEND VLLM_EXT_SRC
${
MARLIN_TEMPLATE_KERNEL_SRC
}
)
set
(
MARLIN_SRCS
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
...
...
@@ -644,7 +688,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL
${
MOE_MARLIN_GEN_SCRIPT_HASH
}
)
execute_process
(
COMMAND
${
CMAKE_COMMAND
}
-E env
PYTHONPATH=
$
{
CMAKE_CURRENT_SOURCE_DIR
}
/csrc/cutlass_extensions/:
${
CUTLASS_DIR
}
/python/:
${
VLLM_PYTHON_PATH
}
:$
PYTHONPATH
PYTHONPATH=$PYTHONPATH
${
Python_EXECUTABLE
}
${
MOE_MARLIN_GEN_SCRIPT
}
RESULT_VARIABLE moe_marlin_generation_result
OUTPUT_VARIABLE moe_marlin_generation_output
...
...
csrc/moe/marlin_moe_wna16/.gitignore
0 → 100644
View file @
1d0c9d6b
kernel_*.cu
\ No newline at end of file
csrc/moe/marlin_moe_wna16/generate_kernels.py
View file @
1d0c9d6b
...
...
@@ -25,15 +25,13 @@ TEMPLATE = ("template __global__ void Marlin<"
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{'true' if has_act_order else 'false'}}, "
"{{'true' if has_zp else 'false'}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );"
)
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES
=
[
"vllm::kU4"
,
"vllm::kU4B8"
,
"vllm::kU8B128"
]
SCALAR_TYPES
=
[
"vllm::kU4"
,
"vllm::kU4B8"
,
"vllm::kU8B128"
,
"vllm::kFE4M3fn"
]
THREAD_CONFIGS
=
[(
128
,
128
,
256
),
(
64
,
256
,
256
),
(
64
,
128
,
128
)]
THREAD_M_BLOCKS
=
[
0.5
,
1
,
2
,
3
,
4
]
...
...
@@ -52,21 +50,29 @@ def remove_old_kernels():
def
generate_new_kernels
():
for
scalar_type
,
dtype
in
itertools
.
product
(
SCALAR_TYPES
,
DTYPES
):
has_zp
=
"B"
not
in
scalar_type
all_template_str_list
=
[]
for
group_blocks
,
m_blocks
,
thread_configs
in
itertools
.
product
(
GROUP_BLOCKS
,
THREAD_M_BLOCKS
,
THREAD_CONFIGS
):
has_act_order
=
group_blocks
==
0
if
has_zp
and
has_act_order
:
# act order case only support gptq-int4 and gptq-int8
if
group_blocks
==
0
and
scalar_type
not
in
[
"vllm::kU4B8"
,
"vllm::kU8B128"
]:
continue
if
thread_configs
[
2
]
==
256
:
# for small batch (m_blocks == 1), we only need (128, 128, 256)
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if
m_blocks
<=
1
and
thread_configs
[
0
]
!=
128
:
continue
if
m_blocks
>
1
and
thread_configs
[
0
]
!=
64
:
continue
# we only support channelwise quantization and group_size == 128
# for fp8
if
scalar_type
==
"vllm::kFE4M3fn"
and
group_blocks
not
in
[
-
1
,
8
]:
continue
k_blocks
=
thread_configs
[
0
]
//
16
n_blocks
=
thread_configs
[
1
]
//
16
threads
=
thread_configs
[
2
]
...
...
@@ -82,8 +88,6 @@ def generate_new_kernels():
thread_k_blocks
=
k_blocks
,
m_block_size_8
=
m_blocks
==
0.5
,
stages
=
"pipe_stages"
,
has_act_order
=
has_act_order
,
has_zp
=
has_zp
,
group_blocks
=
group_blocks
,
is_zp_float
=
False
,
)
...
...
csrc/moe/marlin_moe_wna16/kernel.h
View file @
1d0c9d6b
...
...
@@ -18,7 +18,7 @@
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce
bool use_fp32_reduce
, int max_shared_mem
namespace
MARLIN_NAMESPACE_NAME
{
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
...
...
@@ -33,11 +33,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_zp
,
// whether zero-points are enabled
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
>
__global__
void
Marlin
(
MARLIN_KERNEL_PARAMS
);
...
...
csrc/moe/marlin_moe_wna16/marlin_template.h
View file @
1d0c9d6b
...
...
@@ -25,6 +25,7 @@
#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "quantization/gptq_marlin/dequant.h"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
...
...
@@ -48,11 +49,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_zp
,
// whether zero-points are enabled
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
>
__global__
void
Marlin
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
...
...
@@ -77,8 +76,8 @@ __global__ void Marlin(
int
prob_k
,
// reduction dimension k
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_atomic_add
,
// whether to use atomic add to reduce
bool
use_fp32_reduce
// whether to use fp32 global reduce
)
{}
bool
use_fp32_reduce
,
// whether to use fp32 global reduce
int
max_shared_mem
)
{}
}
// namespace MARLIN_NAMESPACE_NAME
...
...
@@ -166,144 +165,6 @@ __device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a,
}
}
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template
<
int
lut
>
__device__
inline
int
lop3
(
int
a
,
int
b
,
int
c
)
{
int
res
;
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
res
)
:
"r"
(
a
),
"r"
(
b
),
"r"
(
c
),
"n"
(
lut
));
return
res
;
}
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template
<
int
start_byte
,
int
mask
>
__device__
inline
uint32_t
prmt
(
uint32_t
a
)
{
uint32_t
res
;
asm
volatile
(
"prmt.b32 %0, %1, %2, %3;
\n
"
:
"=r"
(
res
)
:
"r"
(
a
),
"n"
(
start_byte
),
"n"
(
mask
));
return
res
;
}
template
<
typename
scalar_t
,
int
bit
>
__device__
inline
typename
ScalarType
<
scalar_t
>::
FragB
dequant
(
int
q
,
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
);
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template
<
>
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant
<
half
,
4
>
(
int
q
,
typename
ScalarType
<
half
>::
FragB
&
frag_b
)
{
const
int
LO
=
0x000f000f
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const
int
SUB
=
0x64086408
;
const
int
MUL
=
0x2c002c00
;
const
int
ADD
=
0xd480d480
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
SUB
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
MUL
),
*
reinterpret_cast
<
const
half2
*>
(
&
ADD
));
return
frag_b
;
}
template
<
>
__device__
inline
typename
ScalarType
<
nv_bfloat16
>::
FragB
dequant
<
nv_bfloat16
,
4
>
(
int
q
,
typename
ScalarType
<
nv_bfloat16
>::
FragB
&
frag_b
)
{
static
constexpr
uint32_t
MASK
=
0x000f000f
;
static
constexpr
uint32_t
EX
=
0x43004300
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
static
constexpr
uint32_t
MUL
=
0x3F803F80
;
static
constexpr
uint32_t
ADD
=
0xC308C308
;
frag_b
[
0
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
lo
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
hi
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
return
frag_b
;
}
//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template
<
>
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant
<
half
,
8
>
(
int
q
,
typename
ScalarType
<
half
>::
FragB
&
frag_b
)
{
static
constexpr
uint32_t
mask_for_elt_01
=
0x5250
;
static
constexpr
uint32_t
mask_for_elt_23
=
0x5351
;
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
uint32_t
lo
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_01
>
(
q
);
uint32_t
hi
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_23
>
(
q
);
static
constexpr
uint32_t
I8s_TO_F16s_MAGIC_NUM
=
0x64806480
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
frag_b
[
1
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
return
frag_b
;
}
template
<
>
__device__
inline
typename
ScalarType
<
nv_bfloat16
>::
FragB
dequant
<
nv_bfloat16
,
8
>
(
int
q
,
typename
ScalarType
<
nv_bfloat16
>::
FragB
&
frag_b
)
{
float
fp32_intermediates
[
4
];
uint32_t
*
fp32_intermediates_casted
=
reinterpret_cast
<
uint32_t
*>
(
fp32_intermediates
);
static
constexpr
uint32_t
fp32_base
=
0x4B000000
;
fp32_intermediates_casted
[
0
]
=
__byte_perm
(
q
,
fp32_base
,
0x7650
);
fp32_intermediates_casted
[
1
]
=
__byte_perm
(
q
,
fp32_base
,
0x7652
);
fp32_intermediates_casted
[
2
]
=
__byte_perm
(
q
,
fp32_base
,
0x7651
);
fp32_intermediates_casted
[
3
]
=
__byte_perm
(
q
,
fp32_base
,
0x7653
);
fp32_intermediates
[
0
]
-=
8388736.
f
;
fp32_intermediates
[
1
]
-=
8388736.
f
;
fp32_intermediates
[
2
]
-=
8388736.
f
;
fp32_intermediates
[
3
]
-=
8388736.
f
;
uint32_t
*
bf16_result_ptr
=
reinterpret_cast
<
uint32_t
*>
(
&
frag_b
);
bf16_result_ptr
[
0
]
=
__byte_perm
(
fp32_intermediates_casted
[
0
],
fp32_intermediates_casted
[
1
],
0x7632
);
bf16_result_ptr
[
1
]
=
__byte_perm
(
fp32_intermediates_casted
[
2
],
fp32_intermediates_casted
[
3
],
0x7632
);
return
frag_b
;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
template
<
typename
scalar_t
>
...
...
@@ -429,11 +290,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_zp
,
// whether zero-points are enabled
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
>
__global__
void
Marlin
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
...
...
@@ -458,8 +317,8 @@ __global__ void Marlin(
int
prob_k
,
// reduction dimension k
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_atomic_add
,
// whether to use atomic add to reduce
bool
use_fp32_reduce
// whether to use fp32 global reduce
)
{
bool
use_fp32_reduce
,
// whether to use fp32 global reduce
int
max_shared_mem
)
{
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
...
...
@@ -481,6 +340,8 @@ __global__ void Marlin(
extern
__shared__
int4
sh
[];
static
constexpr
auto
w_type
=
vllm
::
ScalarType
::
from_id
(
w_type_id
);
constexpr
bool
has_zp
=
w_type
==
vllm
::
kU4
||
w_type
==
vllm
::
kU8
;
constexpr
bool
has_act_order
=
group_blocks
==
0
;
constexpr
int
pack_factor
=
32
/
w_type
.
size_bits
();
static_assert
(
thread_m_blocks
==
1
||
!
m_block_size_8
);
...
...
@@ -534,13 +395,20 @@ __global__ void Marlin(
int64_t
B_expert_off
=
0
;
int4
*
sh_block_sorted_ids_int4
=
sh
;
int4
*
sh_rd_block_sorted_ids_int4
=
sh_block_sorted_ids_int4
+
moe_block_size
/
4
;
int4
*
sh_block_topk_weights_int4
=
sh_rd_block_sorted_ids_int4
+
moe_block_size
/
4
;
// sh_block_topk_weights_int4 only need (moe_block_size / 4);
// but we pad to align to 256 bytes
int4
*
sh_new
=
sh_block_topk_weights_int4
+
moe_block_size
/
2
+
moe_block_size
;
int32_t
*
sh_block_sorted_ids
=
reinterpret_cast
<
int
*>
(
sh_block_sorted_ids_int4
);
int
4
*
sh_block_
topk_weights_int4
=
sh_block_sorted_ids_int4
+
moe_block_size
/
4
;
int
32_t
*
sh_
rd_
block_
sorted_ids
=
reinterpret_cast
<
int
*>
(
sh_rd_block_sorted_ids_int4
)
;
scalar_t2
*
sh_block_topk_weights
=
reinterpret_cast
<
scalar_t2
*>
(
sh_block_topk_weights_int4
);
int4
*
sh_new
=
sh_block_topk_weights_int4
+
moe_block_size
/
4
;
int32_t
block_num_valid_tokens
=
0
;
int32_t
locks_off
=
0
;
...
...
@@ -584,6 +452,11 @@ __global__ void Marlin(
sh_block_sorted_ids_int4
[
tid4
]
=
reinterpret_cast
<
const
int4
*>
(
sorted_token_ids_ptr
)[
block_id
*
moe_block_size
/
4
+
tid4
];
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
sh_rd_block_sorted_ids
[
tid4
*
4
+
i
]
=
sh_block_sorted_ids
[
tid4
*
4
+
i
]
/
top_k
;
if
(
mul_topk_weights
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
...
...
@@ -743,6 +616,7 @@ __global__ void Marlin(
constexpr
int
g_idx_stage
=
has_act_order
?
(
tb_k
*
sizeof
(
int
))
/
16
:
0
;
// constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups;
constexpr
int
act_s_max_num_groups
=
32
;
int
act_s_col_stride
=
1
;
int
act_s_col_warp_stride
=
act_s_col_stride
*
8
;
int
tb_n_warps
=
thread_n_blocks
/
4
;
...
...
@@ -758,9 +632,9 @@ __global__ void Marlin(
int
zp_gl_rd_delta
=
zp_gl_stride
;
// Global A read index of current thread.
int
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
)
;
a_gl_rd
+=
a_gl_rd_delta_o
*
slice_row
;
int
a_gl_rd
_row
=
threadIdx
.
x
/
a_gl_rd_delta_o
;
int
a_gl_rd_col
=
a_gl_rd_delta_o
*
slice_row
+
threadIdx
.
x
%
a_gl_rd_delta_o
;
// Shared write index of current thread.
int
a_sh_wr
=
a_sh_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
...
...
@@ -774,8 +648,8 @@ __global__ void Marlin(
(
threadIdx
.
x
%
b_sh_stride_threads
)
*
b_thread_vecs
;
b_gl_rd
+=
b_sh_stride
*
slice_col
;
b_gl_rd
+=
b_gl_rd_delta_o
*
slice_row
;
int
b_sh_wr
=
threadIdx
.
x
*
b_thread_vecs
;
int
b_sh_rd
=
threadIdx
.
x
*
b_thread_vecs
;
auto
b_sh_wr
=
threadIdx
.
x
*
b_thread_vecs
;
auto
b_sh_rd
=
threadIdx
.
x
*
b_thread_vecs
;
// For act_order
constexpr
int
k_iter_size
=
tb_k
/
b_sh_wr_iters
;
...
...
@@ -794,7 +668,7 @@ __global__ void Marlin(
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
int
s_sh_wr
=
threadIdx
.
x
;
auto
s_sh_wr
=
threadIdx
.
x
;
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
// Zero-points
...
...
@@ -807,7 +681,7 @@ __global__ void Marlin(
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
int
zp_sh_wr
=
threadIdx
.
x
;
auto
zp_sh_wr
=
threadIdx
.
x
;
bool
zp_sh_wr_pred
=
threadIdx
.
x
<
zp_sh_stride
;
// We use a different scale layout for grouped and column-wise quantization as
...
...
@@ -851,7 +725,7 @@ __global__ void Marlin(
// each warp must also write a consecutive memory segment?
auto
transform_a
=
[
&
](
int
i
)
{
int
row
=
i
/
a_gl_rd_delta_o
;
return
a_gl_rd_delta_o
*
row
+
(
i
%
a_gl_rd_delta_o
)
^
row
;
return
a_gl_rd_delta_o
*
row
+
(
i
%
a_gl_rd_delta_o
)
^
(
row
%
8
)
;
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
...
...
@@ -879,12 +753,28 @@ __global__ void Marlin(
B_ptr
[
i
]
=
B
+
b_gl_rd_delta_i
*
i
+
b_gl_rd
;
// Shared memory storage for global fetch pipelines.
int4
*
sh_a
=
sh_new
;
int4
*
sh_b
=
sh_a
+
(
stages
*
a_sh_stage
);
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
constexpr
int
sh_red_size
=
(
2
*
thread_n_blocks
+
1
)
*
16
*
thread_m_blocks
;
constexpr
int
sh_b_size
=
stages
*
b_sh_stage
;
int4
*
sh_b
=
sh_new
;
int4
*
sh_red
=
sh_new
;
int4
*
sh_g_idx
=
sh_b
+
(
sh_red_size
>
sh_b_size
?
sh_red_size
:
sh_b_size
);
int4
*
sh_zp
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
constexpr
int
sh_s_size
=
has_act_order
?
(
act_s_max_num_groups
*
s_sh_stride
)
:
(
stages
*
s_sh_stage
);
int4
*
sh_s
=
sh_zp
+
(
stages
*
zp_sh_stage
);
int4
*
sh_red
=
sh_b
;
// shared memory reused by reduction should be smaller than
// shared memory used by weight.
static_assert
(
thread_m_blocks
*
16
*
thread_n_blocks
*
16
/
8
<=
stages
*
b_sh_stage
);
int4
*
sh_a
=
sh_s
+
sh_s_size
;
constexpr
int
shm_size_used
=
moe_block_size
+
stages
*
(
g_idx_stage
+
zp_sh_stage
)
+
sh_s_size
+
(
sh_red_size
>
sh_b_size
?
sh_red_size
:
sh_b_size
);
// all remaining shared memory is used to cache A (input)
// sh_a_max_row is at least ` stages * 16 * thread_m_blocks `
int
sh_a_max_row
=
((
max_shared_mem
-
1024
)
/
16
-
shm_size_used
)
/
(
thread_k_blocks
*
2
);
// Register storage for double buffer of shared memory reads.
FragA
frag_a
[
2
][
thread_m_blocks
];
...
...
@@ -905,15 +795,14 @@ __global__ void Marlin(
int
sh_first_group_id
=
-
1
;
int
sh_num_groups
=
-
1
;
constexpr
int
sh_max_num_groups
=
32
;
auto
fetch_act_order_scales_to_shared
=
[
&
](
bool
is_async
,
int
first_group_id
,
int
last_group_id
)
{
sh_first_group_id
=
first_group_id
;
sh_num_groups
=
last_group_id
-
first_group_id
+
1
;
if
(
sh_num_groups
<
s
h
_max_num_groups
)
{
sh_num_groups
=
s
h
_max_num_groups
;
if
(
sh_num_groups
<
act_
s_max_num_groups
)
{
sh_num_groups
=
act_
s_max_num_groups
;
}
if
(
sh_first_group_id
+
sh_num_groups
>
num_groups
)
{
...
...
@@ -940,27 +829,31 @@ __global__ void Marlin(
}
}
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
int
a_remaining_load_count_in_slice
=
stages
;
auto
fetch_to_shared
=
[
&
](
int
pipe
,
int
a_off
,
bool
pred
=
true
)
{
bool
should_load_a
=
true
;
int
max_num_stage_groups
=
((
sh_a_max_row
-
moe_block_size
)
/
moe_block_size
+
1
)
/
stages
;
max_num_stage_groups
=
max
(
max_num_stage_groups
,
1
);
auto
fetch_to_shared
=
[
&
](
int
pipe
,
int
a_off
,
bool
pred
=
true
,
int
pipe_a
=
0
)
{
if
(
pred
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
if
(
prob_k
>
thread_k_blocks
*
16
*
stages
||
slice_col
==
0
||
a_remaining_load_count_in_slice
>
0
)
{
a_remaining_load_count_in_slice
--
;
if
(
should_load_a
)
{
int4
*
sh_a_stage
=
sh_a
+
moe_block_size
*
a_sh_stride
*
pipe_a
;
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
{
int
a_idx
=
a_gl_rd_delta_i
*
i
+
a_gl_rd
+
a_gl_rd_delta_o
*
a_off
;
int
row
=
a_idx
/
a_gl_stride
;
int
row
=
a_gl_rd_delta_i
/
a_gl_stride
*
i
+
a_gl_rd_row
;
int64_t
sorted_row
=
0
;
if
(
!
m_block_size_8
||
row
<
8
)
sorted_row
=
sh_block_sorted_ids
[
row
]
/
top_k
;
int64_t
true_idx
=
sorted_row
*
a_gl_stride
+
a_idx
%
a_gl_stride
;
sorted_row
=
sh_rd_block_sorted_ids
[
row
];
int64_t
true_idx
=
sorted_row
*
a_gl_stride
+
a_gl_rd_col
+
a_gl_rd_delta_o
*
a_off
;
cp_async4_pred
(
&
sh_a_stage
[
a_sh_wr_trans
[
i
]],
&
A
[
true_idx
],
row
<
block_num_valid_tokens
);
}
}
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
...
...
@@ -1063,8 +956,8 @@ __global__ void Marlin(
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto
fetch_to_registers
=
[
&
](
int
k
,
int
pipe
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_st
ag
e
*
pipe
;
auto
fetch_to_registers
=
[
&
](
int
k
,
int
pipe
,
int
pipe_a
=
0
)
{
int4
*
sh_a_stage
=
sh_a
+
moe_block_size
*
a_sh_st
rid
e
*
pipe
_a
;
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
ldsm
<
m_block_size_8
?
2
:
4
,
scalar_t
>
(
...
...
@@ -1109,12 +1002,17 @@ __global__ void Marlin(
}
}
else
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
];
if
(
k
%
b_sh_wr_iters
==
0
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
];
}
else
{
reinterpret_cast
<
int4
*>
(
&
frag_s
[
1
])[
0
]
=
reinterpret_cast
<
int4
*>
(
&
frag_s
[
0
])[
0
];
}
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
...
...
@@ -1152,7 +1050,7 @@ __global__ void Marlin(
// Determine "position" inside the thread-block (based on warp and
// thread-id)
int
warp_id
=
threadIdx
.
x
/
32
;
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
// Each warp processes 4 16-size tiles over N
...
...
@@ -1161,7 +1059,7 @@ __global__ void Marlin(
cur_k
+=
warp_row
*
16
;
int
th_id
=
threadIdx
.
x
%
32
;
auto
th_id
=
threadIdx
.
x
%
32
;
cur_k
+=
(
th_id
%
4
)
*
2
;
// Due to tensor-core layout for fp16 B matrix
int
s_col_shift
=
...
...
@@ -1222,15 +1120,18 @@ __global__ void Marlin(
}
}
else
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
if
(
k
%
b_sh_wr_iters
==
0
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
#pragma unroll
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
...
...
@@ -1251,6 +1152,7 @@ __global__ void Marlin(
sh_zp_stage
+=
cur_group_id
*
zp_sh_stride
;
#pragma unroll
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
...
...
@@ -1263,12 +1165,16 @@ __global__ void Marlin(
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_zpf
[
k
%
2
])[
0
]
=
sh_zp_stage
[
zp_sh_rd
];
if
(
k
%
b_sh_wr_iters
==
0
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_zpf
[
k
%
2
])[
0
]
=
sh_zp_stage
[
zp_sh_rd
];
}
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
...
...
@@ -1292,6 +1198,25 @@ __global__ void Marlin(
}
};
auto
dequant_data
=
[
&
](
int
q
,
scalar_t2
*
frag_b_ptr
)
{
if
constexpr
(
has_zp
&&
is_zp_float
||
!
has_zp
)
{
dequant
<
scalar_t2
,
w_type_id
>
(
q
,
frag_b_ptr
);
}
else
{
static_assert
(
has_zp
&&
!
is_zp_float
);
static_assert
(
w_type_id
==
vllm
::
kU4
.
id
()
||
w_type_id
==
vllm
::
kU8
.
id
());
// If (has_zp && !is_zp_float),
// we use not-zp version `dequant` function
// to improve numerical accuracy.
// Since both weight and zero point are dequanted using this logic,
// the final dequanted weight would be correct.
if
constexpr
(
w_type_id
==
vllm
::
kU4
.
id
())
{
dequant
<
scalar_t2
,
vllm
::
kU4B8
.
id
()
>
(
q
,
frag_b_ptr
);
}
else
if
constexpr
(
w_type_id
==
vllm
::
kU8
.
id
())
{
dequant
<
scalar_t2
,
vllm
::
kU8B128
.
id
()
>
(
q
,
frag_b_ptr
);
}
}
};
// Execute the actual tensor core matmul of a sub-tile.
bool
is_first_matmul_in_slice
=
true
;
auto
matmul
=
[
&
](
int
k
)
{
...
...
@@ -1315,15 +1240,17 @@ __global__ void Marlin(
zp_quant_1
=
frag_qzp
[
k2
][
1
];
}
dequant
<
scalar_t
,
w_type
.
size_bits
()
>
(
zp_quant_0
,
frag_zp_0
);
dequant
<
scalar_t
,
w_type
.
size_bits
()
>
(
zp_quant_1
,
frag_zp_1
);
frag_zp
[
0
]
=
frag_zp_0
[
0
];
frag_zp
[
1
]
=
frag_zp_0
[
1
];
frag_zp
[
2
]
=
frag_zp_1
[
0
];
frag_zp
[
3
]
=
frag_zp_1
[
1
];
dequant_data
(
zp_quant_0
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_zp
));
dequant_data
(
zp_quant_1
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_zp
)
+
2
);
}
}
if
constexpr
(
has_zp
&&
is_zp_float
)
{
if
(
is_new_zp
)
{
reinterpret_cast
<
int4
*>
(
&
frag_zp
)[
0
]
=
reinterpret_cast
<
int4
*>
(
&
frag_zpf
[
k2
])[
0
];
}
}
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
...
...
@@ -1342,8 +1269,8 @@ __global__ void Marlin(
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
}
dequant
<
scalar_t
,
w_type
.
size_bits
()
>
(
b_quant_0
,
frag_b0
);
dequant
<
scalar_t
,
w_type
.
size_bits
()
>
(
b_quant_1
,
frag_b1
);
dequant
_data
(
b_quant_0
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_b0
)
)
;
dequant
_data
(
b_quant_1
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_b1
)
)
;
// Apply scale to frag_b0
if
constexpr
(
has_act_order
)
{
...
...
@@ -1351,8 +1278,7 @@ __global__ void Marlin(
scale4
<
scalar_t
>
(
frag_b0
,
act_frag_s
[
k2
][
0
][
j
],
act_frag_s
[
k2
][
1
][
j
],
act_frag_s
[
k2
][
2
][
j
],
act_frag_s
[
k2
][
3
][
j
],
0
);
scale4
<
scalar_t
>
(
frag_b1
,
act_frag_s
[
k2
][
0
][
j
],
act_frag_s
[
k2
][
1
][
j
],
act_frag_s
[
k
][
2
][
j
],
act_frag_s
[
k2
][
3
][
j
],
1
);
act_frag_s
[
k2
][
2
][
j
],
act_frag_s
[
k2
][
3
][
j
],
1
);
}
else
if
constexpr
(
has_zp
&&
!
is_zp_float
&&
group_blocks
==
-
1
)
{
int
idx
=
(
threadIdx
.
x
/
4
)
%
2
;
scalar_t2
s2
=
Dtype
::
nums2num2
(
...
...
@@ -1361,18 +1287,12 @@ __global__ void Marlin(
if
(
is_new_zp
)
frag_zp
[
j
]
=
__hmul2
(
frag_zp
[
j
],
s2
);
scale_and_sub
<
scalar_t
>
(
frag_b0
,
s2
.
x
,
frag_zp
[
j
].
x
);
scale_and_sub
<
scalar_t
>
(
frag_b1
,
s2
.
y
,
frag_zp
[
j
].
y
);
}
else
if
constexpr
(
has_zp
&&
!
is_zp_float
&&
group_blocks
!=
-
1
)
{
}
else
if
constexpr
(
has_zp
&&
group_blocks
!=
-
1
)
{
if
(
is_new_zp
)
frag_zp
[
j
]
=
__hmul2
(
frag_zp
[
j
],
*
reinterpret_cast
<
scalar_t2
*>
(
&
frag_s
[
k2
][
j
]));
scale_and_sub
<
scalar_t
>
(
frag_b0
,
frag_s
[
k
%
2
][
j
][
0
].
x
,
frag_zp
[
j
].
x
);
scale_and_sub
<
scalar_t
>
(
frag_b1
,
frag_s
[
k
%
2
][
j
][
0
].
y
,
frag_zp
[
j
].
y
);
}
else
if
constexpr
(
has_zp
&&
is_zp_float
&&
group_blocks
!=
-
1
)
{
if
(
is_new_zp
)
frag_zpf
[
k2
][
j
]
=
__hmul2
(
frag_zpf
[
k2
][
j
],
*
reinterpret_cast
<
scalar_t2
*>
(
&
frag_s
[
k2
][
j
]));
scale_and_sub
<
scalar_t
>
(
frag_b0
,
frag_s
[
k2
][
j
].
x
,
frag_zpf
[
k2
][
j
].
x
);
scale_and_sub
<
scalar_t
>
(
frag_b1
,
frag_s
[
k2
][
j
].
y
,
frag_zpf
[
k2
][
j
].
y
);
scale_and_sub
<
scalar_t
>
(
frag_b0
,
frag_s
[
k2
][
j
][
0
].
x
,
frag_zp
[
j
].
x
);
scale_and_sub
<
scalar_t
>
(
frag_b1
,
frag_s
[
k2
][
j
][
0
].
y
,
frag_zp
[
j
].
y
);
}
else
if
constexpr
(
group_blocks
!=
-
1
)
{
scale
<
scalar_t
>
(
frag_b0
,
frag_s
[
k2
][
j
],
0
);
scale
<
scalar_t
>
(
frag_b1
,
frag_s
[
k2
][
j
],
1
);
...
...
@@ -1397,7 +1317,7 @@ __global__ void Marlin(
auto
thread_block_reduce
=
[
&
]()
{
constexpr
int
red_off
=
threads
/
b_sh_stride_threads
/
2
;
if
(
red_off
>=
1
)
{
int
red_idx
=
threadIdx
.
x
/
b_sh_stride_threads
;
auto
red_idx
=
threadIdx
.
x
/
b_sh_stride_threads
;
constexpr
int
red_sh_stride
=
b_sh_stride_threads
*
4
*
2
;
constexpr
int
red_sh_delta
=
b_sh_stride_threads
;
int
red_sh_rd
=
red_sh_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
...
...
@@ -1731,7 +1651,7 @@ __global__ void Marlin(
fetch_col_scale_to_shared
();
}
}
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
);
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
,
i
);
}
zero_accums
();
...
...
@@ -1740,8 +1660,10 @@ __global__ void Marlin(
fetch_to_registers
(
0
,
0
);
fetch_scales_to_registers
(
0
,
0
);
fetch_zp_to_registers
(
0
,
0
);
a_gl_rd
+=
a_gl_rd_delta_o
*
(
stages
-
1
);
slice_k_start_shared_fetch
+=
tb_k
*
(
stages
-
1
);
a_gl_rd_col
+=
a_gl_rd_delta_o
*
(
stages
-
1
);
if
constexpr
(
has_act_order
)
{
slice_k_start_shared_fetch
+=
tb_k
*
(
stages
-
1
);
}
};
if
(
slice_iters
)
{
start_pipes
();
...
...
@@ -1754,43 +1676,56 @@ __global__ void Marlin(
// have even length meaning that the next iteration will always start at
// index 0.
for
(
int
stage_group_id
=
0
;
stage_group_id
<
max_num_stage_groups
;
stage_group_id
++
)
{
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
stages
;)
{
for
(
int
pipe
=
0
;
pipe
<
stages
;)
{
#pragma unroll
for
(
int
k
=
0
;
k
<
b_sh_wr_iters
;
k
++
)
{
fetch_to_registers
(
k
+
1
,
pipe
%
stages
);
fetch_scales_to_registers
(
k
+
1
,
pipe
);
fetch_zp_to_registers
(
k
+
1
,
pipe
);
if
(
k
==
b_sh_wr_iters
-
2
)
{
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
slice_iters
>=
stages
);
pipe
++
;
wait_for_stage
();
init_same_group
(
pipe
%
stages
);
for
(
int
k
=
0
;
k
<
b_sh_wr_iters
;
k
++
)
{
int
idx
=
(
pipe
>=
stages
&&
stage_group_id
==
max_num_stage_groups
-
1
)
?
(
pipe
-
stages
)
:
(
pipe
+
stage_group_id
*
stages
);
fetch_to_registers
(
k
+
1
,
pipe
%
stages
,
idx
);
fetch_scales_to_registers
(
k
+
1
,
pipe
);
fetch_zp_to_registers
(
k
+
1
,
pipe
);
if
(
k
==
b_sh_wr_iters
-
2
)
{
int
idx
=
(
pipe
>=
1
&&
stage_group_id
==
max_num_stage_groups
-
1
)
?
(
pipe
-
1
)
:
(
pipe
+
(
stage_group_id
+
1
)
*
stages
-
1
);
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
slice_iters
>=
stages
,
idx
);
pipe
++
;
wait_for_stage
();
init_same_group
(
pipe
%
stages
);
}
matmul
(
k
);
}
slice_iters
--
;
if
(
slice_iters
==
0
)
{
break
;
}
matmul
(
k
);
}
slice_iters
--
;
if
(
slice_iters
==
0
)
{
break
;
}
}
a_remaining_load_count_in_slice
=
0
;
a_gl_rd
+=
a_gl_rd_delta_o
*
stages
;
slice_k_start
+=
tb_k
*
stages
;
slice_k_start_shared_fetch
+=
tb_k
*
stages
;
a_gl_rd_col
+=
a_gl_rd_delta_o
*
stages
;
if
constexpr
(
has_act_order
)
{
int
first_group_id
=
g_idx
[
slice_k_start
];
int
last_g_idx
=
slice_k_start
+
stages
*
tb_k
*
2
;
if
(
last_g_idx
>=
prob_k
)
{
last_g_idx
=
prob_k
-
1
;
if
constexpr
(
has_act_order
)
{
slice_k_start
+=
tb_k
*
stages
;
slice_k_start_shared_fetch
+=
tb_k
*
stages
;
int
first_group_id
=
g_idx
[
slice_k_start
];
int
last_g_idx
=
slice_k_start
+
stages
*
tb_k
*
2
;
if
(
last_g_idx
>=
prob_k
)
{
last_g_idx
=
prob_k
-
1
;
}
int
last_group_id
=
g_idx
[
last_g_idx
];
if
(
last_group_id
>=
sh_first_group_id
+
sh_num_groups
)
{
fetch_act_order_scales_to_shared
(
false
,
first_group_id
,
last_group_id
);
__syncthreads
();
}
}
int
last_group_id
=
g_idx
[
last_g_idx
];
if
(
last_group_id
>=
sh_first_group_id
+
sh_num_groups
)
{
fetch_act_order_scales_to_shared
(
false
,
first_group_id
,
last_group_id
);
__syncthreads
();
if
(
slice_iters
==
0
)
{
break
;
}
}
...
...
@@ -1877,15 +1812,30 @@ __global__ void Marlin(
if
(
last
||
use_atomic_add
)
// only the last block in a slice actually writes the result
write_result
();
i
f
(
slice_row
)
a_remaining_load_count_in_slice
=
stages
;
i
nt
old_
slice_row
=
slice_row
;
slice_row
=
0
;
slice_col_par
++
;
slice_col
++
;
is_first_matmul_in_slice
=
true
;
init_slice
();
// Should we load A matrix in next slice?
// `slice_col == 0`: when move to a new moe block
// `old_slice_row > 0`:
// when the last slice is not starting from k_index == 0
// (only happen when it is the first slice of a threadblock)
// `prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups`:
// when the required shared memory size is larger than
// the remaining shared memory
if
(
slice_col
==
0
||
old_slice_row
||
prob_k
>
thread_k_blocks
*
16
*
stages
*
max_num_stage_groups
)
{
should_load_a
=
true
;
}
else
{
should_load_a
=
false
;
}
if
(
slice_iters
)
{
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
a_gl_rd_col
=
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
+=
b_sh_stride
-
b_gl_rd_delta_o
*
k_tiles
;
...
...
@@ -1900,12 +1850,10 @@ __global__ void Marlin(
slice_k_finish
=
slice_k_start
+
tb_k
*
slice_iters
;
slice_k_start_shared_fetch
=
slice_k_start
;
slice_n_offset
=
act_s_col_tb_stride
*
slice_col
;
}
else
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
zp_gl_rd
=
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
start_pipes
();
}
}
...
...
csrc/moe/marlin_moe_wna16/ops.cu
View file @
1d0c9d6b
...
...
@@ -116,7 +116,7 @@ __global__ void permute_cols_kernel(
int
base_k
=
0
;
for
(
int
i
=
0
;
i
<
iters
;
i
++
)
{
int
cur_k
=
base_k
+
threadIdx
.
x
;
auto
cur_k
=
base_k
+
threadIdx
.
x
;
int
src_pos
=
perm_int_ptr
[
cur_k
];
out_half
[
cur_k
]
=
a_row_half
[
src_pos
];
...
...
@@ -126,7 +126,7 @@ __global__ void permute_cols_kernel(
if
(
rest
)
{
if
(
threadIdx
.
x
<
rest
)
{
int
cur_k
=
base_k
+
threadIdx
.
x
;
auto
cur_k
=
base_k
+
threadIdx
.
x
;
int
src_pos
=
perm_int_ptr
[
cur_k
];
out_half
[
cur_k
]
=
a_row_half
[
src_pos
];
...
...
@@ -195,7 +195,6 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
tb_groups
*
pipe_stages
*
2
;
// Chunk size is 2x pipeline over dim K
load_groups
=
max
(
load_groups
,
32
);
// We load at least 32 scale groups
return
load_groups
*
tb_n
*
2
;
}
else
{
int
tb_scales
=
tb_groups
*
tb_n
*
2
;
...
...
@@ -203,22 +202,24 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
}
}
int
get_kernel_cache_size
(
thread_config_t
const
&
th_config
,
int
thread_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
has_zp
,
int
is_zp_float
)
{
int
get_kernel_cache_size
(
thread_config_t
const
&
th_config
,
bool
m_block_size_8
,
int
thread_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
has_zp
,
int
is_zp_float
)
{
int
pack_factor
=
32
/
num_bits
;
// Get B size
int
tb_k
=
th_config
.
thread_k
;
int
tb_n
=
th_config
.
thread_n
;
int
tb_m
=
thread_m_blocks
*
16
;
int
tb_m
=
thread_m_blocks
*
(
m_block_size_8
?
8
:
16
)
;
// shm size for block_sorted_ids/block_topk_weights
// shm size for
block_sorted_ids/rd_
block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
int
sh_block_meta_size
=
tb_m
*
4
*
2
;
int
sh_block_meta_size
=
tb_m
*
4
;
int
sh_a_size
=
pipe_stages
*
(
tb_m
*
tb_k
)
*
2
;
int
sh_b_size
=
pipe_stages
*
(
tb_k
*
tb_n
/
pack_factor
)
*
4
;
int
sh_red_size
=
tb_m
*
(
tb_n
+
8
)
*
2
;
int
sh_s_size
=
get_scales_cache_size
(
th_config
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
);
...
...
@@ -233,16 +234,17 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
sh_zp_size
=
sh_s_size
/
2
;
}
int
total_size
=
sh_
a
_size
+
sh_
b
_size
+
sh_
s
_size
+
sh_
zp
_size
+
sh_g_idx_size
+
sh_block_meta_size
;
int
total_size
=
max
(
sh_
b
_size
,
sh_
red
_size
)
+
sh_
a
_size
+
sh_
s
_size
+
sh_zp_size
+
sh_g_idx_size
+
sh_block_meta_size
;
return
total_size
;
}
bool
is_valid_config
(
thread_config_t
const
&
th_config
,
int
thread_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
has_zp
,
int
is_zp_float
,
int
max_shared_mem
)
{
bool
is_valid_config
(
thread_config_t
const
&
th_config
,
bool
m_block_size_8
,
int
thread_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
has_zp
,
int
is_zp_float
,
int
max_shared_mem
)
{
// Sanity
if
(
th_config
.
thread_k
==
-
1
||
th_config
.
thread_n
==
-
1
||
th_config
.
num_threads
==
-
1
)
{
...
...
@@ -266,143 +268,113 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
// Check that pipeline fits into cache
int
cache_size
=
get_kernel_cache_size
(
th_config
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
);
th_config
,
m_block_size_8
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
);
return
cache_size
<=
max_shared_mem
;
}
#define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
IS_ZP_FLOAT>; \
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
}
#define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false)
#define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false)
#define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false)
#define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false)
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
// this is the most common cases
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order case (group_blocks == 0)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF(W_TYPE) \
COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
COMMON_GET_IF_M234(W_TYPE, 8, 4, 128)
#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
true) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true)
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF(W_TYPE) \
FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
FZP_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF(W_TYPE) \
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
ACT_GET_IF_M234(W_TYPE, 8, 4, 128)
template
<
typename
scalar_t
>
MarlinFuncPtr
get_marlin_kernel
(
const
vllm
::
ScalarType
q_type
,
...
...
@@ -415,23 +387,15 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
auto
kernel
=
MarlinDefault
;
if
(
false
)
{
}
GPTQ_GET_IF_M1
(
vllm
::
kU4B8
,
8
,
8
,
256
)
GPTQ_GET_IF_M1
(
vllm
::
kU4B8
,
8
,
4
,
128
)
GPTQ_GET_IF_M234
(
vllm
::
kU4B8
,
16
,
4
,
256
)
GPTQ_GET_IF_M234
(
vllm
::
kU4B8
,
8
,
4
,
128
)
COMMON_GET_IF
(
vllm
::
kU4
)
COMMON_GET_IF
(
vllm
::
kU4B8
)
COMMON_GET_IF
(
vllm
::
kU8B128
)
GPTQ_GET_IF_M1
(
vllm
::
kU8B128
,
8
,
8
,
256
)
GPTQ_GET_IF_M1
(
vllm
::
kU8B128
,
8
,
4
,
128
)
BIGGROUP_GET_IF
(
vllm
::
kFE4M3fn
)
GPTQ_GET_IF_M234
(
vllm
::
kU8B128
,
16
,
4
,
256
)
GPTQ_GET_IF_M234
(
vllm
::
kU8B128
,
8
,
4
,
128
)
AWQ_GET_IF_M1
(
vllm
::
kU4
,
8
,
8
,
256
)
AWQ_GET_IF_M1
(
vllm
::
kU4
,
8
,
4
,
128
)
AWQ_GET_IF_M234
(
vllm
::
kU4
,
16
,
4
,
256
)
AWQ_GET_IF_M234
(
vllm
::
kU4
,
8
,
4
,
128
)
ACT_GET_IF
(
vllm
::
kU4B8
)
ACT_GET_IF
(
vllm
::
kU8B128
)
return
kernel
;
}
...
...
@@ -457,19 +421,19 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
for
(
int
i
=
0
;
i
<
thread_configs_size
;
i
++
)
{
thread_config_t
th_config
=
thread_configs
[
i
];
if
(
!
is_valid_config
(
th_config
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
max_shared_mem
))
{
if
(
!
is_valid_config
(
th_config
,
m_block_size_8
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
max_shared_mem
))
{
continue
;
}
int
cache_size
=
get_kernel_cache_size
(
th_config
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
);
th_config
,
m_block_size_8
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
);
int
group_blocks
=
0
;
if
(
!
has_act_order
)
{
group_blocks
=
group_size
==
-
1
?
-
1
:
group_size
/
16
;
group_blocks
=
group_size
==
-
1
?
-
1
:
(
group_size
/
16
)
;
}
auto
kernel
=
get_marlin_kernel
<
scalar_t
>
(
...
...
@@ -515,14 +479,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
bool
m_block_size_8
=
moe_block_size
==
8
;
if
(
has_zp
)
{
TORCH_CHECK
(
q_type
==
vllm
::
kU4
||
q_type
==
vllm
::
kU8
,
"q_type must be u4 or u8 when has_zp = True. Got = "
,
q_type
.
str
());
TORCH_CHECK
(
q_type
==
vllm
::
kU4
,
"q_type must be u4 when has_zp = True. Got = "
,
q_type
.
str
());
}
else
{
TORCH_CHECK
(
q_type
==
vllm
::
kU4B8
||
q_type
==
vllm
::
kU8B128
,
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = "
,
q_type
.
str
());
TORCH_CHECK
(
q_type
==
vllm
::
kU4B8
||
q_type
==
vllm
::
kU8B128
||
q_type
==
vllm
::
kFE4M3fn
,
"q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
"False. Got = "
,
q_type
.
str
());
}
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
...
...
@@ -631,18 +595,18 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
int
thread_k_blocks
=
thread_k
/
16
;
int
thread_n_blocks
=
thread_n
/
16
;
TORCH_CHECK
(
is_valid_config
(
thread_tfg
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
max_shared_mem
)
,
"Invalid thread config: thread_m_blocks = "
,
th
re
a
d_m
_blocks
,
", thread_k
= "
,
thread_
tfg
.
thread_k
,
", thread_
n
= "
,
thread_tfg
.
thread_
n
,
",
num_
thread
s
= "
,
thread_tfg
.
num_
thread
s
,
" for MKN = ["
,
prob_m
,
", "
,
prob_k
,
",
"
,
prob_
n
,
"
] and num_bits = "
,
num_bits
,
",
g
ro
up_size = "
,
group_size
,
", has_act_order = "
,
has_act_order
,
", is_k_full = "
,
is_k_full
,
",
has_zp = "
,
has_zp
,
",
i
s_zp
_float
= "
,
i
s_zp
_float
,
", max_shared_mem = "
,
max_shared_mem
);
TORCH_CHECK
(
is_valid_config
(
thread_tfg
,
m_block_size_8
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
max_sha
red_m
em
)
,
"Invalid thread config: thread_m_blocks
= "
,
thread_
m_blocks
,
", thread_
k
= "
,
thread_tfg
.
thread_
k
,
", thread
_n
= "
,
thread_tfg
.
thread
_n
,
", num_threads = "
,
thread_tfg
.
num_threads
,
" for MKN = [
"
,
prob_
m
,
"
, "
,
prob_k
,
",
"
,
p
ro
b_n
,
"] and num_bits = "
,
num_bits
,
", group_size = "
,
group_size
,
", has_act_order = "
,
has_act_order
,
",
is_k_full = "
,
is_k_full
,
",
ha
s_zp = "
,
ha
s_zp
,
", is_zp_float = "
,
is_zp_float
,
", max_shared_mem = "
,
max_shared_mem
);
auto
kernel
=
get_marlin_kernel
<
scalar_t
>
(
q_type
,
thread_m_blocks
,
thread_n_blocks
,
thread_k_blocks
,
m_block_size_8
,
...
...
@@ -666,7 +630,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
A_ptr
,
B_ptr
,
C_ptr
,
C_tmp_ptr
,
s_ptr
,
zp_ptr
,
g_idx_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
num_tokens_past_padded_ptr
,
topk_weights_ptr
,
top_k
,
mul_topk_weights
,
is_ep
,
num_groups
,
prob_m
,
prob_n
,
prob_k
,
locks
,
use_atomic_add
,
use_fp32_reduce
);
prob_n
,
prob_k
,
locks
,
use_atomic_add
,
use_fp32_reduce
,
max_shared_mem
);
// clang-format on
}
...
...
@@ -841,10 +805,11 @@ torch::Tensor moe_wna16_marlin_gemm(
b_q_type
==
vllm
::
kU4
,
"b_q_type must be u4 when has_zp = True. Got = "
,
b_q_type
.
str
());
}
else
{
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4B8
||
b_q_type
==
vllm
::
kU8B128
,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = "
,
b_q_type
.
str
());
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4B8
||
b_q_type
==
vllm
::
kU8B128
||
b_q_type
==
vllm
::
kFE4M3fn
,
"b_q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
"False. Got = "
,
b_q_type
.
str
());
}
if
(
has_zp
&&
is_zp_float
)
{
...
...
csrc/quantization/gptq_marlin/.gitignore
0 → 100644
View file @
1d0c9d6b
kernel_*.cu
\ No newline at end of file
csrc/quantization/gptq_marlin/dequant.h
0 → 100644
View file @
1d0c9d6b
#include "marlin_dtypes.cuh"
namespace
MARLIN_NAMESPACE_NAME
{
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template
<
int
lut
>
__device__
inline
int
lop3
(
int
a
,
int
b
,
int
c
)
{
int
res
;
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
res
)
:
"r"
(
a
),
"r"
(
b
),
"r"
(
c
),
"n"
(
lut
));
return
res
;
}
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template
<
int
start_byte
,
int
mask
>
__device__
inline
uint32_t
prmt
(
uint32_t
a
)
{
uint32_t
res
;
asm
volatile
(
"prmt.b32 %0, %1, %2, %3;
\n
"
:
"=r"
(
res
)
:
"r"
(
a
),
"n"
(
start_byte
),
"n"
(
mask
));
return
res
;
}
template
<
typename
scalar_t2
,
vllm
::
ScalarTypeId
w_type_id
>
__device__
inline
void
dequant
(
int
q
,
scalar_t2
*
frag_b
);
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template
<
>
__device__
inline
void
dequant
<
half2
,
vllm
::
kU4B8
.
id
()
>
(
int
q
,
half2
*
frag_b
)
{
const
int
LO
=
0x000f000f
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
// clang-format on
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const
int
SUB
=
0x64086408
;
const
int
MUL
=
0x2c002c00
;
const
int
ADD
=
0xd480d480
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
SUB
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
MUL
),
*
reinterpret_cast
<
const
half2
*>
(
&
ADD
));
}
template
<
>
__device__
inline
void
dequant
<
half2
,
vllm
::
kU4
.
id
()
>
(
int
q
,
half2
*
frag_b
)
{
const
int
LO
=
0x000f000f
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
// clang-format on
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const
int
SUB
=
0x64006400
;
const
int
MUL
=
0x2c002c00
;
const
int
ADD
=
0xd400d400
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
SUB
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
MUL
),
*
reinterpret_cast
<
const
half2
*>
(
&
ADD
));
}
template
<
>
__device__
inline
void
dequant
<
nv_bfloat162
,
vllm
::
kU4B8
.
id
()
>
(
int
q
,
nv_bfloat162
*
frag_b
)
{
static
constexpr
uint32_t
MASK
=
0x000f000f
;
static
constexpr
uint32_t
EX
=
0x43004300
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
// clang-format on
static
constexpr
uint32_t
MUL
=
0x3F803F80
;
static
constexpr
uint32_t
ADD
=
0xC308C308
;
frag_b
[
0
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
lo
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
hi
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
}
template
<
>
__device__
inline
void
dequant
<
nv_bfloat162
,
vllm
::
kU4
.
id
()
>
(
int
q
,
nv_bfloat162
*
frag_b
)
{
static
constexpr
uint32_t
MASK
=
0x000f000f
;
static
constexpr
uint32_t
EX
=
0x43004300
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
// clang-format on
static
constexpr
uint32_t
MUL
=
0x3F803F80
;
static
constexpr
uint32_t
ADD
=
0xC300C300
;
frag_b
[
0
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
lo
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
hi
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
}
//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template
<
>
__device__
inline
void
dequant
<
half2
,
vllm
::
kU8B128
.
id
()
>
(
int
q
,
half2
*
frag_b
)
{
static
constexpr
uint32_t
mask_for_elt_01
=
0x5250
;
static
constexpr
uint32_t
mask_for_elt_23
=
0x5351
;
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
uint32_t
lo
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_01
>
(
q
);
uint32_t
hi
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_23
>
(
q
);
static
constexpr
uint32_t
I8s_TO_F16s_MAGIC_NUM
=
0x64806480
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
frag_b
[
1
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
}
template
<
>
__device__
inline
void
dequant
<
half2
,
vllm
::
kU8
.
id
()
>
(
int
q
,
half2
*
frag_b
)
{
static
constexpr
uint32_t
mask_for_elt_01
=
0x5250
;
static
constexpr
uint32_t
mask_for_elt_23
=
0x5351
;
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
uint32_t
lo
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_01
>
(
q
);
uint32_t
hi
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_23
>
(
q
);
static
constexpr
uint32_t
I8s_TO_F16s_MAGIC_NUM
=
0x64006400
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
frag_b
[
1
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
}
template
<
>
__device__
inline
void
dequant
<
nv_bfloat162
,
vllm
::
kU8B128
.
id
()
>
(
int
q
,
nv_bfloat162
*
frag_b
)
{
float
fp32_intermediates
[
4
];
uint32_t
*
fp32_intermediates_casted
=
reinterpret_cast
<
uint32_t
*>
(
fp32_intermediates
);
static
constexpr
uint32_t
fp32_base
=
0x4B000000
;
fp32_intermediates_casted
[
0
]
=
__byte_perm
(
q
,
fp32_base
,
0x7650
);
fp32_intermediates_casted
[
1
]
=
__byte_perm
(
q
,
fp32_base
,
0x7652
);
fp32_intermediates_casted
[
2
]
=
__byte_perm
(
q
,
fp32_base
,
0x7651
);
fp32_intermediates_casted
[
3
]
=
__byte_perm
(
q
,
fp32_base
,
0x7653
);
fp32_intermediates
[
0
]
-=
8388736.
f
;
fp32_intermediates
[
1
]
-=
8388736.
f
;
fp32_intermediates
[
2
]
-=
8388736.
f
;
fp32_intermediates
[
3
]
-=
8388736.
f
;
uint32_t
*
bf16_result_ptr
=
reinterpret_cast
<
uint32_t
*>
(
frag_b
);
bf16_result_ptr
[
0
]
=
__byte_perm
(
fp32_intermediates_casted
[
0
],
fp32_intermediates_casted
[
1
],
0x7632
);
bf16_result_ptr
[
1
]
=
__byte_perm
(
fp32_intermediates_casted
[
2
],
fp32_intermediates_casted
[
3
],
0x7632
);
}
template
<
>
__device__
inline
void
dequant
<
nv_bfloat162
,
vllm
::
kU8
.
id
()
>
(
int
q
,
nv_bfloat162
*
frag_b
)
{
float
fp32_intermediates
[
4
];
uint32_t
*
fp32_intermediates_casted
=
reinterpret_cast
<
uint32_t
*>
(
fp32_intermediates
);
static
constexpr
uint32_t
fp32_base
=
0x4B000000
;
fp32_intermediates_casted
[
0
]
=
__byte_perm
(
q
,
fp32_base
,
0x7650
);
fp32_intermediates_casted
[
1
]
=
__byte_perm
(
q
,
fp32_base
,
0x7652
);
fp32_intermediates_casted
[
2
]
=
__byte_perm
(
q
,
fp32_base
,
0x7651
);
fp32_intermediates_casted
[
3
]
=
__byte_perm
(
q
,
fp32_base
,
0x7653
);
fp32_intermediates
[
0
]
-=
8388608.
f
;
fp32_intermediates
[
1
]
-=
8388608.
f
;
fp32_intermediates
[
2
]
-=
8388608.
f
;
fp32_intermediates
[
3
]
-=
8388608.
f
;
uint32_t
*
bf16_result_ptr
=
reinterpret_cast
<
uint32_t
*>
(
frag_b
);
bf16_result_ptr
[
0
]
=
__byte_perm
(
fp32_intermediates_casted
[
0
],
fp32_intermediates_casted
[
1
],
0x7632
);
bf16_result_ptr
[
1
]
=
__byte_perm
(
fp32_intermediates_casted
[
2
],
fp32_intermediates_casted
[
3
],
0x7632
);
}
template
<
>
__device__
inline
void
dequant
<
half2
,
vllm
::
kFE4M3fn
.
id
()
>
(
int
q
,
half2
*
frag_b
)
{
// Constants for FP8 (E4M3) and FP16 formats
constexpr
int
FP8_EXPONENT
=
4
,
FP8_MANTISSA
=
3
,
FP16_EXPONENT
=
5
;
constexpr
int
RIGHT_SHIFT
=
FP16_EXPONENT
-
FP8_EXPONENT
;
// Calculate MASK for extracting mantissa and exponent
constexpr
int
MASK1
=
0x80000000
;
constexpr
int
MASK2
=
MASK1
>>
(
FP8_EXPONENT
+
FP8_MANTISSA
);
constexpr
int
MASK3
=
MASK2
&
0x7fffffff
;
constexpr
int
MASK
=
MASK3
|
(
MASK3
>>
16
);
// Final MASK value: 0x7F007F00
// Extract and shift FP8 values to FP16 format
int
Out1
=
(
q
&
0x80008000
)
|
((
q
&
MASK
)
>>
RIGHT_SHIFT
);
int
Out2
=
((
q
<<
8
)
&
0x80008000
)
|
(((
q
<<
8
)
&
MASK
)
>>
RIGHT_SHIFT
);
// Construct and apply exponent bias
constexpr
int
BIAS_OFFSET
=
(
1
<<
(
FP16_EXPONENT
-
1
))
-
(
1
<<
(
FP8_EXPONENT
-
1
));
const
half2
bias_reg
=
__float2half2_rn
(
float
(
1
<<
BIAS_OFFSET
));
// Convert to half2 and apply bias
// Note: reverse indexing is intentional because weights are permuted
frag_b
[
1
]
=
__hmul2
(
*
reinterpret_cast
<
const
half2
*>
(
&
Out1
),
bias_reg
);
frag_b
[
0
]
=
__hmul2
(
*
reinterpret_cast
<
const
half2
*>
(
&
Out2
),
bias_reg
);
}
template
<
>
__device__
inline
void
dequant
<
nv_bfloat162
,
vllm
::
kFE4M3fn
.
id
()
>
(
int
q
,
nv_bfloat162
*
frag_b
)
{
// Constants for FP8 (E4M3) and BF16 formats
constexpr
int
FP8_EXPONENT
=
4
,
FP8_MANTISSA
=
3
,
BF16_EXPONENT
=
8
;
constexpr
int
RIGHT_SHIFT
=
BF16_EXPONENT
-
FP8_EXPONENT
;
// Calculate MASK for extracting mantissa and exponent
constexpr
int
MASK1
=
0x80000000
;
constexpr
int
MASK2
=
MASK1
>>
(
FP8_EXPONENT
+
FP8_MANTISSA
);
constexpr
int
MASK3
=
MASK2
&
0x7fffffff
;
constexpr
int
MASK
=
MASK3
|
(
MASK3
>>
16
);
// Final MASK value: 0x7F007F00
// Extract and shift FP8 values to BF16 format
int
Out1
=
(
q
&
0x80008000
)
|
((
q
&
MASK
)
>>
RIGHT_SHIFT
);
int
Out2
=
((
q
<<
8
)
&
0x80008000
)
|
(((
q
<<
8
)
&
MASK
)
>>
RIGHT_SHIFT
);
// Construct and apply exponent bias
constexpr
int
BIAS_OFFSET
=
(
1
<<
(
BF16_EXPONENT
-
1
))
-
(
1
<<
(
FP8_EXPONENT
-
1
));
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
// position
constexpr
uint32_t
BIAS
=
(
BIAS_OFFSET
+
127
)
<<
23
;
const
nv_bfloat162
bias_reg
=
__float2bfloat162_rn
(
*
reinterpret_cast
<
const
float
*>
(
&
BIAS
));
// Convert to bfloat162 and apply bias
// Note: reverse indexing is intentional because weights are permuted
frag_b
[
1
]
=
__hmul2
(
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
Out1
),
bias_reg
);
frag_b
[
0
]
=
__hmul2
(
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
Out2
),
bias_reg
);
}
#endif
}
// namespace MARLIN_NAMESPACE_NAME
csrc/quantization/gptq_marlin/generate_kernels.py
0 → 100644
View file @
1d0c9d6b
# SPDX-License-Identifier: Apache-2.0
import
glob
import
itertools
import
os
import
subprocess
import
jinja2
FILE_HEAD
=
"""
// auto generated by generate.py
// clang-format off
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
"""
.
strip
()
TEMPLATE
=
(
"template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );"
)
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES
=
[
"vllm::kU4"
,
"vllm::kU4B8"
,
"vllm::kU8B128"
,
"vllm::kFE4M3fn"
]
THREAD_CONFIGS
=
[(
128
,
128
,
256
),
(
64
,
256
,
256
),
(
64
,
128
,
128
),
(
128
,
64
,
128
)]
THREAD_M_BLOCKS
=
[
0.5
,
1
,
2
,
3
,
4
]
# group_blocks:
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS
=
[
0
,
-
1
,
2
,
4
,
8
]
DTYPES
=
[
"fp16"
,
"bf16"
]
def
remove_old_kernels
():
for
filename
in
glob
.
glob
(
os
.
path
.
dirname
(
__file__
)
+
"/kernel_*.cu"
):
subprocess
.
call
([
"rm"
,
"-f"
,
filename
])
def
generate_new_kernels
():
for
scalar_type
,
dtype
in
itertools
.
product
(
SCALAR_TYPES
,
DTYPES
):
all_template_str_list
=
[]
for
group_blocks
,
m_blocks
,
thread_configs
in
itertools
.
product
(
GROUP_BLOCKS
,
THREAD_M_BLOCKS
,
THREAD_CONFIGS
):
# act order case only support gptq-int4 and gptq-int8
if
group_blocks
==
0
and
scalar_type
not
in
[
"vllm::kU4B8"
,
"vllm::kU8B128"
]:
continue
if
thread_configs
[
2
]
==
256
:
# for small batch (m_blocks == 1), we only need (128, 128, 256)
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if
m_blocks
<=
1
and
thread_configs
[
0
]
!=
128
:
continue
if
m_blocks
>
1
and
thread_configs
[
0
]
!=
64
:
continue
# we only support channelwise quantization and group_size == 128
# for fp8
if
scalar_type
==
"vllm::kFE4M3fn"
and
group_blocks
not
in
[
-
1
,
8
]:
continue
k_blocks
=
thread_configs
[
0
]
//
16
n_blocks
=
thread_configs
[
1
]
//
16
threads
=
thread_configs
[
2
]
c_dtype
=
"half"
if
dtype
==
"fp16"
else
"nv_bfloat16"
is_zp_float_list
=
[
False
]
if
dtype
==
"fp16"
and
scalar_type
==
"vllm::kU4"
and
\
group_blocks
==
4
:
# HQQ (is_zp_float = true) only supports
# 4bit quantization and fp16
is_zp_float_list
.
append
(
True
)
for
is_zp_float
in
is_zp_float_list
:
template_str
=
jinja2
.
Template
(
TEMPLATE
).
render
(
scalar_t
=
c_dtype
,
w_type_id
=
scalar_type
+
".id()"
,
threads
=
threads
,
thread_m_blocks
=
max
(
m_blocks
,
1
),
thread_n_blocks
=
n_blocks
,
thread_k_blocks
=
k_blocks
,
m_block_size_8
=
m_blocks
==
0.5
,
stages
=
"pipe_stages"
,
group_blocks
=
group_blocks
,
is_zp_float
=
is_zp_float
,
)
all_template_str_list
.
append
(
template_str
)
file_content
=
FILE_HEAD
+
"
\n\n
"
file_content
+=
"
\n\n
"
.
join
(
all_template_str_list
)
+
"
\n\n
}
\n
"
filename
=
f
"kernel_
{
dtype
}
_
{
scalar_type
[
6
:].
lower
()
}
.cu"
with
open
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
filename
),
"w"
)
as
f
:
f
.
write
(
file_content
)
if
__name__
==
"__main__"
:
remove_old_kernels
()
generate_new_kernels
()
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
1d0c9d6b
...
...
@@ -19,10 +19,11 @@
* Adapted from https://github.com/IST-DASLab/marlin
*/
#i
nclude "marlin.cuh"
#include "marlin_dtypes.cuh"
#
include "core/scalar_type.hpp"
#i
fndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#
endif
#include "kernel.h"
#include "core/registration.h"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
...
...
@@ -30,13 +31,12 @@
std::is_same<scalar_t, nv_bfloat16>::value, \
"only float16 and bfloat16 is supported");
template
<
typename
T
>
inline
std
::
string
str
(
T
x
)
{
return
std
::
to_string
(
x
);
}
namespace
marlin
{
__global__
void
MarlinDefault
(
MARLIN_KERNEL_PARAMS
){};
using
MarlinFuncPtr
=
void
(
*
)(
MARLIN_KERNEL_PARAMS
);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
__global__
void
permute_cols_kernel
(
int4
const
*
__restrict__
a_int4_ptr
,
...
...
@@ -44,46 +44,17 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int4
*
__restrict__
out_int4_ptr
,
int
size_m
,
int
size_k
,
int
lda
,
int
block_rows
)
{}
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
const
vllm
::
ScalarTypeId
w_type_id
,
// weight ScalarType id
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
int
group_blocks
=
-
1
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
>
__global__
void
Marlin
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C_tmp
,
// fp32 tmp output buffer (for reduce)
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_fp32_reduce
// whether to use fp32 global reduce
)
{}
}
// namespace marlin
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeId
const
b_q_type_id
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
,
bool
is_zp_float
)
{
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
std
::
optional
<
torch
::
Tensor
>
c_or_none
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
b_zeros_or_none
,
std
::
optional
<
torch
::
Tensor
>
const
&
g_idx_or_none
,
std
::
optional
<
torch
::
Tensor
>
const
&
perm_or_none
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeId
const
&
b_q_type_id
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
use_atomic_add
,
bool
use_fp32_reduce
,
bool
is_zp_float
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
...
...
@@ -91,369 +62,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
#else
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
template
<
typename
scalar_t
>
__device__
inline
void
mma
(
const
typename
ScalarType
<
scalar_t
>::
FragA
&
a_frag
,
const
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
typename
ScalarType
<
scalar_t
>::
FragC
&
frag_c
)
{
const
uint32_t
*
a
=
reinterpret_cast
<
const
uint32_t
*>
(
&
a_frag
);
const
uint32_t
*
b
=
reinterpret_cast
<
const
uint32_t
*>
(
&
frag_b
);
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
if
constexpr
(
std
::
is_same
<
scalar_t
,
half
>::
value
)
{
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
)
{
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
{
STATIC_ASSERT_SCALAR_TYPE_VALID
(
scalar_t
);
}
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
template
<
typename
scalar_t
>
__device__
inline
void
ldsm4
(
typename
ScalarType
<
scalar_t
>::
FragA
&
frag_a
,
const
void
*
smem_ptr
)
{
uint32_t
*
a
=
reinterpret_cast
<
uint32_t
*>
(
&
frag_a
);
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
asm
volatile
(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];
\n
"
:
"=r"
(
a
[
0
]),
"=r"
(
a
[
1
]),
"=r"
(
a
[
2
]),
"=r"
(
a
[
3
])
:
"r"
(
smem
));
}
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template
<
int
lut
>
__device__
inline
int
lop3
(
int
a
,
int
b
,
int
c
)
{
int
res
;
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
res
)
:
"r"
(
a
),
"r"
(
b
),
"r"
(
c
),
"n"
(
lut
));
return
res
;
}
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template
<
int
start_byte
,
int
mask
>
__device__
inline
uint32_t
prmt
(
uint32_t
a
)
{
uint32_t
res
;
asm
volatile
(
"prmt.b32 %0, %1, %2, %3;
\n
"
:
"=r"
(
res
)
:
"r"
(
a
),
"n"
(
start_byte
),
"n"
(
mask
));
return
res
;
}
template
<
typename
scalar_t
,
vllm
::
ScalarTypeId
w_type_id
>
__device__
inline
typename
ScalarType
<
scalar_t
>::
FragB
dequant
(
int
q
);
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template
<
>
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant
<
half
,
vllm
::
kU4B8
.
id
()
>
(
int
q
)
{
const
int
LO
=
0x000f000f
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const
int
SUB
=
0x64086408
;
const
int
MUL
=
0x2c002c00
;
const
int
ADD
=
0xd480d480
;
typename
ScalarType
<
half
>::
FragB
frag_b
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
SUB
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
MUL
),
*
reinterpret_cast
<
const
half2
*>
(
&
ADD
));
return
frag_b
;
}
template
<
>
__device__
inline
typename
ScalarType
<
nv_bfloat16
>::
FragB
dequant
<
nv_bfloat16
,
vllm
::
kU4B8
.
id
()
>
(
int
q
)
{
static
constexpr
uint32_t
MASK
=
0x000f000f
;
static
constexpr
uint32_t
EX
=
0x43004300
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
typename
ScalarType
<
nv_bfloat16
>::
FragB
frag_b
;
static
constexpr
uint32_t
MUL
=
0x3F803F80
;
static
constexpr
uint32_t
ADD
=
0xC308C308
;
frag_b
[
0
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
lo
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
hi
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
return
frag_b
;
}
template
<
>
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant
<
half
,
vllm
::
kU4
.
id
()
>
(
int
q
)
{
const
int
LO
=
0x000f000f
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
const
int
SUB
=
0x64006400
;
const
int
MUL
=
0x2c002c00
;
const
int
ADD
=
0xd400d400
;
typename
ScalarType
<
half
>::
FragB
frag_b
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
SUB
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
MUL
),
*
reinterpret_cast
<
const
half2
*>
(
&
ADD
));
return
frag_b
;
}
template
<
>
__device__
inline
typename
ScalarType
<
nv_bfloat16
>::
FragB
dequant
<
nv_bfloat16
,
vllm
::
kU4
.
id
()
>
(
int
q
)
{
static
constexpr
uint32_t
MASK
=
0x000f000f
;
static
constexpr
uint32_t
EX
=
0x43004300
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
typename
ScalarType
<
nv_bfloat16
>::
FragB
frag_b
;
static
constexpr
uint32_t
MUL
=
0x3F803F80
;
static
constexpr
uint32_t
ADD
=
0xC300C300
;
frag_b
[
0
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
lo
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
hi
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
return
frag_b
;
}
//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template
<
>
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant
<
half
,
vllm
::
kU8B128
.
id
()
>
(
int
q
)
{
static
constexpr
uint32_t
mask_for_elt_01
=
0x5250
;
static
constexpr
uint32_t
mask_for_elt_23
=
0x5351
;
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
uint32_t
lo
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_01
>
(
q
);
uint32_t
hi
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_23
>
(
q
);
static
constexpr
uint32_t
I8s_TO_F16s_MAGIC_NUM
=
0x64806480
;
typename
ScalarType
<
half
>::
FragB
frag_b
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
frag_b
[
1
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
return
frag_b
;
}
template
<
>
__device__
inline
typename
ScalarType
<
nv_bfloat16
>::
FragB
dequant
<
nv_bfloat16
,
vllm
::
kU8B128
.
id
()
>
(
int
q
)
{
typename
ScalarType
<
nv_bfloat16
>::
FragB
frag_b
;
float
fp32_intermediates
[
4
];
uint32_t
*
fp32_intermediates_casted
=
reinterpret_cast
<
uint32_t
*>
(
fp32_intermediates
);
static
constexpr
uint32_t
fp32_base
=
0x4B000000
;
fp32_intermediates_casted
[
0
]
=
__byte_perm
(
q
,
fp32_base
,
0x7650
);
fp32_intermediates_casted
[
1
]
=
__byte_perm
(
q
,
fp32_base
,
0x7652
);
fp32_intermediates_casted
[
2
]
=
__byte_perm
(
q
,
fp32_base
,
0x7651
);
fp32_intermediates_casted
[
3
]
=
__byte_perm
(
q
,
fp32_base
,
0x7653
);
fp32_intermediates
[
0
]
-=
8388736.
f
;
fp32_intermediates
[
1
]
-=
8388736.
f
;
fp32_intermediates
[
2
]
-=
8388736.
f
;
fp32_intermediates
[
3
]
-=
8388736.
f
;
uint32_t
*
bf16_result_ptr
=
reinterpret_cast
<
uint32_t
*>
(
&
frag_b
);
bf16_result_ptr
[
0
]
=
__byte_perm
(
fp32_intermediates_casted
[
0
],
fp32_intermediates_casted
[
1
],
0x7632
);
bf16_result_ptr
[
1
]
=
__byte_perm
(
fp32_intermediates_casted
[
2
],
fp32_intermediates_casted
[
3
],
0x7632
);
return
frag_b
;
}
template
<
>
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant
<
half
,
vllm
::
kU8
.
id
()
>
(
int
q
)
{
static
constexpr
uint32_t
mask_for_elt_01
=
0x5250
;
static
constexpr
uint32_t
mask_for_elt_23
=
0x5351
;
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
uint32_t
lo
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_01
>
(
q
);
uint32_t
hi
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_23
>
(
q
);
static
constexpr
uint32_t
I8s_TO_F16s_MAGIC_NUM
=
0x64006400
;
typename
ScalarType
<
half
>::
FragB
frag_b
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
frag_b
[
1
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
return
frag_b
;
}
template
<
>
__device__
inline
typename
ScalarType
<
nv_bfloat16
>::
FragB
dequant
<
nv_bfloat16
,
vllm
::
kU8
.
id
()
>
(
int
q
)
{
typename
ScalarType
<
nv_bfloat16
>::
FragB
frag_b
;
float
fp32_intermediates
[
4
];
uint32_t
*
fp32_intermediates_casted
=
reinterpret_cast
<
uint32_t
*>
(
fp32_intermediates
);
static
constexpr
uint32_t
fp32_base
=
0x4B000000
;
fp32_intermediates_casted
[
0
]
=
__byte_perm
(
q
,
fp32_base
,
0x7650
);
fp32_intermediates_casted
[
1
]
=
__byte_perm
(
q
,
fp32_base
,
0x7652
);
fp32_intermediates_casted
[
2
]
=
__byte_perm
(
q
,
fp32_base
,
0x7651
);
fp32_intermediates_casted
[
3
]
=
__byte_perm
(
q
,
fp32_base
,
0x7653
);
fp32_intermediates
[
0
]
-=
8388608.
f
;
fp32_intermediates
[
1
]
-=
8388608.
f
;
fp32_intermediates
[
2
]
-=
8388608.
f
;
fp32_intermediates
[
3
]
-=
8388608.
f
;
uint32_t
*
bf16_result_ptr
=
reinterpret_cast
<
uint32_t
*>
(
&
frag_b
);
bf16_result_ptr
[
0
]
=
__byte_perm
(
fp32_intermediates_casted
[
0
],
fp32_intermediates_casted
[
1
],
0x7632
);
bf16_result_ptr
[
1
]
=
__byte_perm
(
fp32_intermediates_casted
[
2
],
fp32_intermediates_casted
[
3
],
0x7632
);
return
frag_b
;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
template
<
typename
scalar_t
>
__device__
inline
void
scale
(
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
typename
ScalarType
<
scalar_t
>::
FragS
&
frag_s
,
int
i
)
{
using
scalar_t2
=
typename
ScalarType
<
scalar_t
>::
scalar_t2
;
scalar_t2
s
=
ScalarType
<
scalar_t
>::
num2num2
(
reinterpret_cast
<
scalar_t
*>
(
&
frag_s
)[
i
]);
frag_b
[
0
]
=
__hmul2
(
frag_b
[
0
],
s
);
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s
);
}
template
<
typename
scalar_t
>
__device__
inline
void
sub_zp
(
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
typename
ScalarType
<
scalar_t
>::
scalar_t2
&
frag_zp
,
int
i
)
{
using
scalar_t2
=
typename
ScalarType
<
scalar_t
>::
scalar_t2
;
scalar_t2
zp
=
ScalarType
<
scalar_t
>::
num2num2
(
reinterpret_cast
<
scalar_t
*>
(
&
frag_zp
)[
i
]);
frag_b
[
0
]
=
__hsub2
(
frag_b
[
0
],
zp
);
frag_b
[
1
]
=
__hsub2
(
frag_b
[
1
],
zp
);
}
// Same as above, but for act_order (each K is multiplied individually)
template
<
typename
scalar_t
>
__device__
inline
void
scale4
(
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
typename
ScalarType
<
scalar_t
>::
FragS
&
frag_s_1
,
typename
ScalarType
<
scalar_t
>::
FragS
&
frag_s_2
,
typename
ScalarType
<
scalar_t
>::
FragS
&
frag_s_3
,
typename
ScalarType
<
scalar_t
>::
FragS
&
frag_s_4
,
int
i
)
{
using
scalar_t2
=
typename
ScalarType
<
scalar_t
>::
scalar_t2
;
scalar_t2
s_val_1_2
;
s_val_1_2
.
x
=
reinterpret_cast
<
scalar_t
*>
(
&
frag_s_1
)[
i
];
s_val_1_2
.
y
=
reinterpret_cast
<
scalar_t
*>
(
&
frag_s_2
)[
i
];
scalar_t2
s_val_3_4
;
s_val_3_4
.
x
=
reinterpret_cast
<
scalar_t
*>
(
&
frag_s_3
)[
i
];
s_val_3_4
.
y
=
reinterpret_cast
<
scalar_t
*>
(
&
frag_s_4
)[
i
];
frag_b
[
0
]
=
__hmul2
(
frag_b
[
0
],
s_val_1_2
);
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s_val_3_4
);
}
// Given 2 floats multiply by 2 scales (halves)
template
<
typename
scalar_t
>
__device__
inline
void
scale_float
(
float
*
c
,
typename
ScalarType
<
scalar_t
>::
FragS
&
s
)
{
scalar_t
*
s_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
s
);
c
[
0
]
=
__fmul_rn
(
c
[
0
],
ScalarType
<
scalar_t
>::
num2float
(
s_ptr
[
0
]));
c
[
1
]
=
__fmul_rn
(
c
[
1
],
ScalarType
<
scalar_t
>::
num2float
(
s_ptr
[
1
]));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__
inline
void
barrier_acquire
(
int
*
lock
,
int
count
)
{
if
(
threadIdx
.
x
==
0
)
{
int
state
=
-
1
;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm
volatile
(
"ld.global.acquire.gpu.b32 %0, [%1];
\n
"
:
"=r"
(
state
)
:
"l"
(
lock
));
while
(
state
!=
count
);
}
__syncthreads
();
}
// Release barrier and increment visitation count.
__device__
inline
void
barrier_release
(
int
*
lock
,
bool
reset
=
false
)
{
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
if
(
reset
)
{
lock
[
0
]
=
0
;
return
;
}
int
val
=
1
;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm
volatile
(
"fence.acq_rel.gpu;
\n
"
);
asm
volatile
(
"red.relaxed.gpu.global.add.s32 [%0], %1;
\n
"
:
:
"l"
(
lock
),
"r"
(
val
));
}
}
// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
__global__
void
permute_cols_kernel
(
int4
const
*
__restrict__
a_int4_ptr
,
...
...
@@ -510,1304 +118,19 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
}
}
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
const
vllm
::
ScalarTypeId
w_type_id
,
// weight ScalarType id
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_zp
,
// whether zero-points are enabled
const
int
group_blocks
=
-
1
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
>
__global__
void
Marlin
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C_tmp
,
// fp32 tmp output buffer (for reduce)
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
lda
,
// A.stride(0), equal to prob_k is A is contiguous
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_atomic_add
,
// whether to use atomic add to reduce
bool
use_fp32_reduce
// whether to use fp32 global reduce
)
{
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
// example:
// 0 1 3
// 0 2 3
// 1 2 4
// While this kind of partitioning makes things somewhat more complicated, it
// ensures good utilization of all SMs for many kinds of shape and GPU
// configurations, while requiring as few slow global cross-threadblock
// reductions as possible.
using
Dtype
=
ScalarType
<
scalar_t
>
;
using
scalar_t2
=
typename
ScalarType
<
scalar_t
>::
scalar_t2
;
using
FragA
=
typename
ScalarType
<
scalar_t
>::
FragA
;
using
FragB
=
typename
ScalarType
<
scalar_t
>::
FragB
;
using
FragC
=
typename
ScalarType
<
scalar_t
>::
FragC
;
using
FragS
=
typename
ScalarType
<
scalar_t
>::
FragS
;
using
FragZP
=
typename
ScalarType
<
scalar_t
>::
FragZP
;
static
constexpr
auto
w_type
=
vllm
::
ScalarType
::
from_id
(
w_type_id
);
constexpr
int
pack_factor
=
32
/
w_type
.
size_bits
();
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
int
parallel
=
1
;
if
(
prob_m
>
16
*
thread_m_blocks
)
{
parallel
=
prob_m
/
(
16
*
thread_m_blocks
);
prob_m
=
16
*
thread_m_blocks
;
}
int
k_tiles
=
prob_k
/
16
/
thread_k_blocks
;
int
n_tiles
=
prob_n
/
16
/
thread_n_blocks
;
int
iters
=
div_ceil
(
k_tiles
*
n_tiles
*
parallel
,
gridDim
.
x
);
if
constexpr
(
!
has_act_order
&&
group_blocks
!=
-
1
)
{
if
(
group_blocks
>=
thread_k_blocks
)
{
// Ensure that the number of tiles in each stripe is a multiple of the
// groupsize; this avoids an annoying special case where a stripe starts
// in the middle of group.
iters
=
(
group_blocks
/
thread_k_blocks
)
*
div_ceil
(
iters
,
(
group_blocks
/
thread_k_blocks
));
}
}
int
slice_row
=
(
iters
*
blockIdx
.
x
)
%
k_tiles
;
int
slice_col_par
=
(
iters
*
blockIdx
.
x
)
/
k_tiles
;
int
slice_col
=
slice_col_par
;
int
slice_iters
;
// number of threadblock tiles in the current slice
int
slice_count
=
0
;
// total number of active threadblocks in the current slice
int
slice_idx
;
// index of threadblock in current slice; numbered bottom to
// top
int
par_id
=
0
;
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if
(
slice_col_par
>=
n_tiles
)
{
A
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
lda
/
8
;
C
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
(
slice_col_par
/
n_tiles
)
*
n_tiles
;
slice_col
=
slice_col_par
%
n_tiles
;
par_id
=
slice_col_par
/
n_tiles
;
}
// Compute all information about the current slice which is required for
// synchronization.
auto
init_slice
=
[
&
]()
{
slice_iters
=
iters
*
(
blockIdx
.
x
+
1
)
-
(
k_tiles
*
slice_col_par
+
slice_row
);
if
(
slice_iters
<
0
||
slice_col_par
>=
n_tiles
*
parallel
)
slice_iters
=
0
;
if
(
slice_iters
==
0
)
return
;
if
(
slice_row
+
slice_iters
>
k_tiles
)
slice_iters
=
k_tiles
-
slice_row
;
slice_count
=
1
;
slice_idx
=
0
;
int
col_first
=
iters
*
div_ceil
(
k_tiles
*
slice_col_par
,
iters
);
if
(
col_first
<=
k_tiles
*
(
slice_col_par
+
1
))
{
int
col_off
=
col_first
-
k_tiles
*
slice_col_par
;
slice_count
=
div_ceil
(
k_tiles
-
col_off
,
iters
);
if
(
col_off
>
0
)
slice_count
++
;
int
delta_first
=
iters
*
blockIdx
.
x
-
col_first
;
if
(
delta_first
<
0
||
(
col_off
==
0
&&
delta_first
==
0
))
slice_idx
=
slice_count
-
1
;
else
{
slice_idx
=
slice_count
-
1
-
delta_first
/
iters
;
if
(
col_off
>
0
)
slice_idx
--
;
}
}
if
(
slice_col
==
n_tiles
)
{
A
+=
16
*
thread_m_blocks
*
lda
/
8
;
C
+=
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
n_tiles
;
slice_col
=
0
;
par_id
++
;
}
};
init_slice
();
// A sizes/strides
// stride of the A matrix in global memory
int
a_gl_stride
=
lda
/
8
;
// stride of an A matrix tile in shared memory
constexpr
int
a_sh_stride
=
16
*
thread_k_blocks
/
8
;
// delta between subsequent A tiles in global memory
constexpr
int
a_gl_rd_delta_o
=
16
*
thread_k_blocks
/
8
;
// between subsequent accesses within a tile
int
a_gl_rd_delta_i
=
a_gl_stride
*
(
threads
/
a_gl_rd_delta_o
);
// between shared memory writes
constexpr
int
a_sh_wr_delta
=
a_sh_stride
*
(
threads
/
a_gl_rd_delta_o
);
// between shared memory tile reads
constexpr
int
a_sh_rd_delta_o
=
2
*
((
threads
/
32
)
/
(
thread_n_blocks
/
4
));
// within a shared memory tile
constexpr
int
a_sh_rd_delta_i
=
a_sh_stride
*
16
;
// overall size of a tile
constexpr
int
a_sh_stage
=
a_sh_stride
*
(
16
*
thread_m_blocks
);
// number of shared write iterations for a tile
constexpr
int
a_sh_wr_iters
=
div_ceil
(
a_sh_stage
,
a_sh_wr_delta
);
// B sizes/strides
int
b_gl_stride
=
16
*
prob_n
/
(
pack_factor
*
4
);
constexpr
int
b_sh_stride
=
((
thread_n_blocks
*
16
)
*
16
/
pack_factor
)
/
4
;
constexpr
int
b_thread_vecs
=
w_type
.
size_bits
()
==
4
?
1
:
2
;
constexpr
int
b_sh_stride_threads
=
b_sh_stride
/
b_thread_vecs
;
int
b_gl_rd_delta_o
=
b_gl_stride
*
thread_k_blocks
;
int
b_gl_rd_delta_i
=
b_gl_stride
*
(
threads
/
b_sh_stride_threads
);
constexpr
int
b_sh_wr_delta
=
threads
*
b_thread_vecs
;
constexpr
int
b_sh_rd_delta
=
threads
*
b_thread_vecs
;
constexpr
int
b_sh_stage
=
b_sh_stride
*
thread_k_blocks
;
constexpr
int
b_sh_wr_iters
=
b_sh_stage
/
b_sh_wr_delta
;
// Scale sizes/strides without act_order
int
s_gl_stride
=
prob_n
/
8
;
constexpr
int
s_sh_stride
=
16
*
thread_n_blocks
/
8
;
constexpr
int
s_tb_groups
=
!
has_act_order
&&
group_blocks
!=
-
1
&&
group_blocks
<
thread_k_blocks
?
thread_k_blocks
/
group_blocks
:
1
;
constexpr
int
s_sh_stage
=
s_tb_groups
*
s_sh_stride
;
int
s_gl_rd_delta
=
s_gl_stride
;
// Scale size/strides with act_order
constexpr
int
tb_k
=
16
*
thread_k_blocks
;
constexpr
int
g_idx_stage
=
has_act_order
?
(
tb_k
*
sizeof
(
int
))
/
16
:
0
;
// constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups;
int
act_s_col_stride
=
1
;
int
act_s_col_warp_stride
=
act_s_col_stride
*
8
;
int
tb_n_warps
=
thread_n_blocks
/
4
;
int
act_s_col_tb_stride
=
act_s_col_warp_stride
*
tb_n_warps
;
// Zero-points sizes/strides
int
zp_gl_stride
=
is_zp_float
?
prob_n
/
8
:
(
prob_n
/
pack_factor
)
/
4
;
constexpr
int
zp_sh_stride
=
is_zp_float
?
16
*
thread_n_blocks
/
8
:
((
16
*
thread_n_blocks
)
/
pack_factor
)
/
4
;
constexpr
int
zp_tb_groups
=
s_tb_groups
;
constexpr
int
zp_sh_stage
=
has_zp
?
zp_tb_groups
*
zp_sh_stride
:
0
;
int
zp_gl_rd_delta
=
zp_gl_stride
;
// Global A read index of current thread.
int
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
a_gl_rd
+=
a_gl_rd_delta_o
*
slice_row
;
// Shared write index of current thread.
int
a_sh_wr
=
a_sh_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
// Shared read index.
int
a_sh_rd
=
a_sh_stride
*
((
threadIdx
.
x
%
32
)
%
16
)
+
(
threadIdx
.
x
%
32
)
/
16
;
a_sh_rd
+=
2
*
((
threadIdx
.
x
/
32
)
/
(
thread_n_blocks
/
4
));
int
b_gl_rd
=
b_gl_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
(
threadIdx
.
x
%
b_sh_stride_threads
)
*
b_thread_vecs
;
b_gl_rd
+=
b_sh_stride
*
slice_col
;
b_gl_rd
+=
b_gl_rd_delta_o
*
slice_row
;
auto
b_sh_wr
=
threadIdx
.
x
*
b_thread_vecs
;
auto
b_sh_rd
=
threadIdx
.
x
*
b_thread_vecs
;
// For act_order
constexpr
int
k_iter_size
=
tb_k
/
b_sh_wr_iters
;
int
slice_k_start
=
tb_k
*
slice_row
;
int
slice_k_finish
=
slice_k_start
+
tb_k
*
slice_iters
;
int
slice_k_start_shared_fetch
=
slice_k_start
;
int
slice_n_offset
=
act_s_col_tb_stride
*
slice_col
;
// No act_order
int
s_gl_rd
;
if
constexpr
(
!
has_act_order
)
{
if
constexpr
(
group_blocks
==
-
1
)
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
else
{
s_gl_rd
=
s_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
+
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
auto
s_sh_wr
=
threadIdx
.
x
;
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
// Zero-points
int
zp_gl_rd
;
if
constexpr
(
has_zp
)
{
if
constexpr
(
group_blocks
==
-
1
)
{
zp_gl_rd
=
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
else
{
zp_gl_rd
=
zp_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
+
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
auto
zp_sh_wr
=
threadIdx
.
x
;
bool
zp_sh_wr_pred
=
threadIdx
.
x
<
zp_sh_stride
;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int
s_sh_rd
;
if
constexpr
(
group_blocks
!=
-
1
)
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
4
;
else
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
%
4
;
// Zero-points have the same read layout as the scales
// (without column-wise case)
constexpr
int
num_col_threads
=
8
;
constexpr
int
num_row_threads
=
4
;
constexpr
int
num_ints_per_thread
=
8
/
pack_factor
;
int
zp_sh_rd
;
if
constexpr
(
has_zp
)
{
if
constexpr
(
is_zp_float
)
{
if
constexpr
(
group_blocks
!=
-
1
)
{
zp_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
4
;
}
}
else
{
zp_sh_rd
=
num_ints_per_thread
*
num_col_threads
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
num_ints_per_thread
*
((
threadIdx
.
x
%
32
)
/
num_row_threads
);
}
}
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
bool
a_sh_wr_pred
[
a_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
a_sh_wr_pred
[
i
]
=
a_sh_wr_delta
*
i
+
a_sh_wr
<
a_sh_stride
*
prob_m
;
// To ensure that writing and reading A tiles to/from shared memory, the
// latter in fragment format, is fully bank conflict free, we need to use a
// rather fancy XOR-based layout. The key here is that neither reads nor
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
// same shared memory banks. Further, it seems (based on NSight-Compute) that
// each warp must also write a consecutive memory segment?
auto
transform_a
=
[
&
](
int
i
)
{
int
row
=
i
/
a_gl_rd_delta_o
;
return
a_gl_rd_delta_o
*
row
+
(
i
%
a_gl_rd_delta_o
)
^
row
;
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes.
int
a_sh_wr_trans
[
a_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
a_sh_wr_trans
[
i
]
=
transform_a
(
a_sh_wr_delta
*
i
+
a_sh_wr
);
int
a_sh_rd_trans
[
b_sh_wr_iters
][
thread_m_blocks
];
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
thread_m_blocks
;
j
++
)
a_sh_rd_trans
[
i
][
j
]
=
transform_a
(
a_sh_rd_delta_o
*
i
+
a_sh_rd_delta_i
*
j
+
a_sh_rd
);
}
// Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependencies between subsequent accesses with a tile by
// maintining multiple pointers (we have enough registers), a tiny
// optimization.
const
int4
*
B_ptr
[
b_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
=
B
+
b_gl_rd_delta_i
*
i
+
b_gl_rd
;
extern
__shared__
int4
sh
[];
// Shared memory storage for global fetch pipelines.
int4
*
sh_a
=
sh
;
int4
*
sh_b
=
sh_a
+
(
stages
*
a_sh_stage
);
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
int4
*
sh_zp
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
int4
*
sh_s
=
sh_zp
+
(
stages
*
zp_sh_stage
);
int4
*
sh_red
=
sh_s
+
(
stages
*
s_sh_stage
);
// Register storage for double buffer of shared memory reads.
FragA
frag_a
[
2
][
thread_m_blocks
];
I4
frag_b_quant
[
2
][
b_thread_vecs
];
FragC
frag_c
[
thread_m_blocks
][
4
][
2
];
FragS
frag_s
[
2
][
4
];
// No act-order
FragS
act_frag_s
[
2
][
4
][
4
];
// For act-order
int
frag_qzp
[
2
][
num_ints_per_thread
];
// Zero-points
FragZP
frag_zp
;
// Zero-points in fp16
FragZP
frag_zpf
[
2
];
// Zero-points in fp16 in HQQ
// Zero accumulators.
auto
zero_accums
=
[
&
]()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
*
2
*
4
;
i
++
)
reinterpret_cast
<
float
*>
(
frag_c
)[
i
]
=
0
;
};
int
sh_first_group_id
=
-
1
;
int
sh_num_groups
=
-
1
;
constexpr
int
sh_max_num_groups
=
32
;
auto
fetch_scales_to_shared
=
[
&
](
bool
is_async
,
int
first_group_id
,
int
last_group_id
)
{
sh_first_group_id
=
first_group_id
;
sh_num_groups
=
last_group_id
-
first_group_id
+
1
;
if
(
sh_num_groups
<
sh_max_num_groups
)
{
sh_num_groups
=
sh_max_num_groups
;
}
if
(
sh_first_group_id
+
sh_num_groups
>
num_groups
)
{
sh_num_groups
=
num_groups
-
sh_first_group_id
;
}
int
row_offset
=
first_group_id
*
s_gl_stride
;
if
(
is_async
)
{
for
(
int
i
=
0
;
i
<
sh_num_groups
;
i
++
)
{
if
(
threadIdx
.
x
<
s_sh_stride
)
{
cp_async4_pred
(
&
sh_s
[(
i
*
s_sh_stride
)
+
threadIdx
.
x
],
&
scales_ptr
[
row_offset
+
(
i
*
s_gl_stride
)
+
slice_n_offset
+
threadIdx
.
x
]);
}
}
}
else
{
for
(
int
i
=
0
;
i
<
sh_num_groups
;
i
++
)
{
if
(
threadIdx
.
x
<
s_sh_stride
)
{
sh_s
[(
i
*
s_sh_stride
)
+
threadIdx
.
x
]
=
scales_ptr
[
row_offset
+
(
i
*
s_gl_stride
)
+
slice_n_offset
+
threadIdx
.
x
];
}
}
}
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
auto
fetch_to_shared
=
[
&
](
int
pipe
,
int
a_off
,
bool
pred
=
true
)
{
if
(
pred
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
{
cp_async4_pred
(
&
sh_a_stage
[
a_sh_wr_trans
[
i
]],
&
A
[
a_gl_rd_delta_i
*
i
+
a_gl_rd
+
a_gl_rd_delta_o
*
a_off
],
a_sh_wr_pred
[
i
]);
}
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
b_thread_vecs
;
j
++
)
{
cp_async4
(
&
sh_b_stage
[
b_sh_wr_delta
*
i
+
b_sh_wr
+
j
],
B_ptr
[
i
]
+
j
);
}
B_ptr
[
i
]
+=
b_gl_rd_delta_o
;
}
if
constexpr
(
has_act_order
)
{
// Fetch g_idx thread-block portion
int
full_pipe
=
a_off
;
int
cur_k
=
slice_k_start_shared_fetch
+
tb_k
*
full_pipe
;
if
(
cur_k
<
prob_k
&&
cur_k
<
slice_k_finish
)
{
int4
*
sh_g_idx_stage
=
sh_g_idx
+
g_idx_stage
*
pipe
;
int4
const
*
cur_g_idx_stage_ptr
=
reinterpret_cast
<
int4
const
*>
(
&
g_idx
[
cur_k
]);
if
(
threadIdx
.
x
<
g_idx_stage
)
{
cp_async4_pred
(
&
sh_g_idx_stage
[
threadIdx
.
x
],
&
cur_g_idx_stage_ptr
[
threadIdx
.
x
]);
}
}
}
else
{
if
constexpr
(
group_blocks
!=
-
1
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s_stage
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
// Only fetch scales if this tile starts a new group
if
((
pipe
+
1
)
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
s_gl_rd
+=
s_gl_rd_delta
;
}
}
else
{
for
(
int
i
=
0
;
i
<
s_tb_groups
;
i
++
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s_stage
[
i
*
s_sh_stride
+
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
s_gl_rd
+=
s_gl_rd_delta
;
}
}
}
if
constexpr
(
has_zp
&&
group_blocks
!=
-
1
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pipe
;
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
// Only fetch zero-points if this tile starts a new group
if
(
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
if
(
zp_sh_wr_pred
)
{
cp_async4
(
&
sh_zp_stage
[
zp_sh_wr
],
&
zp_ptr
[
zp_gl_rd
]);
}
zp_gl_rd
+=
zp_gl_rd_delta
;
}
}
else
{
for
(
int
i
=
0
;
i
<
zp_tb_groups
;
i
++
)
{
if
(
zp_sh_wr_pred
)
{
cp_async4
(
&
sh_zp_stage
[
i
*
zp_sh_stride
+
zp_sh_wr
],
&
zp_ptr
[
zp_gl_rd
]);
}
zp_gl_rd
+=
zp_gl_rd_delta
;
}
}
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// waiting is also correct at this point.
cp_async_fence
();
};
auto
fetch_zp_to_shared
=
[
&
]()
{
if
(
zp_sh_wr_pred
)
{
cp_async4
(
&
sh_zp
[
zp_sh_wr
],
&
zp_ptr
[
zp_gl_rd
]);
}
};
// Wait until the next thread tile has been loaded to shared memory.
auto
wait_for_stage
=
[
&
]()
{
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait
<
stages
-
2
>
();
__syncthreads
();
};
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto
fetch_to_registers
=
[
&
](
int
k
,
int
pipe
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
ldsm4
<
scalar_t
>
(
frag_a
[
k
%
2
][
i
],
&
sh_a_stage
[
a_sh_rd_trans
[
k
%
b_sh_wr_iters
][
i
]]);
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_thread_vecs
;
i
++
)
{
frag_b_quant
[
k
%
2
][
i
]
=
*
reinterpret_cast
<
I4
*>
(
&
sh_b_stage
[
b_sh_rd_delta
*
(
k
%
b_sh_wr_iters
)
+
b_sh_rd
+
i
]);
}
};
bool
is_same_group
[
stages
];
int
same_group_id
[
stages
];
auto
init_same_group
=
[
&
](
int
pipe
)
{
if
constexpr
(
!
has_act_order
)
{
is_same_group
[
pipe
]
=
false
;
same_group_id
[
pipe
]
=
0
;
return
;
}
int4
*
sh_g_idx_stage
=
sh_g_idx
+
g_idx_stage
*
pipe
;
int
*
sh_g_idx_int_ptr
=
reinterpret_cast
<
int
*>
(
sh_g_idx_stage
);
int
group_id_1
=
sh_g_idx_int_ptr
[
0
];
int
group_id_2
=
sh_g_idx_int_ptr
[
tb_k
-
1
];
is_same_group
[
pipe
]
=
group_id_1
==
group_id_2
;
same_group_id
[
pipe
]
=
group_id_1
;
};
auto
fetch_scales_to_registers
=
[
&
](
int
k
,
int
full_pipe
)
{
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
!
has_act_order
)
{
// No act-order case
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
];
}
else
{
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
int
cur_k
=
warp_row
*
16
;
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
int
cur_group_id
=
k_blocks
/
group_blocks
;
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
+
cur_group_id
*
s_sh_stride
];
}
}
return
;
}
// Act-order case
// Determine K of the "current" thread-block
int
cur_k
=
slice_k_start
+
tb_k
*
full_pipe
;
if
(
cur_k
>=
prob_k
||
cur_k
>=
slice_k_finish
)
{
return
;
}
// Reset (to current thread-block) since we read g_idx portion from the
// shared memory
cur_k
=
0
;
// Progress to current iteration
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
// Determine "position" inside the thread-block (based on warp and
// thread-id)
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
// Each warp processes 4 16-size tiles over N
int
warp_row
=
warp_id
/
n_warps
;
int
warp_col
=
warp_id
%
n_warps
;
cur_k
+=
warp_row
*
16
;
auto
th_id
=
threadIdx
.
x
%
32
;
cur_k
+=
(
th_id
%
4
)
*
2
;
// Due to tensor-core layout for fp16 B matrix
int
s_col_shift
=
/*slice_n_offset +*/
(
act_s_col_warp_stride
*
warp_col
)
+
(
th_id
/
4
)
*
act_s_col_stride
;
if
(
is_same_group
[
pipe
])
{
if
(
k
%
2
==
0
)
{
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
0
][
0
])))
=
sh_s
[(
same_group_id
[
pipe
]
-
sh_first_group_id
)
*
s_sh_stride
+
s_col_shift
];
}
else
{
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
0
][
0
])))
=
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[(
k
-
1
)
%
2
][
0
][
0
])));
}
for
(
int
i
=
1
;
i
<
4
;
i
++
)
{
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
i
][
0
])))
=
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
0
][
0
])));
}
return
;
}
int4
*
sh_g_idx_stage
=
sh_g_idx
+
g_idx_stage
*
pipe
;
int
*
sh_g_idx_int_ptr
=
reinterpret_cast
<
int
*>
(
sh_g_idx_stage
);
constexpr
int
k_frag_offsets
[
4
]
=
{
0
,
1
,
8
,
9
};
// Tensor core offsets per thread
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
int
actual_k
=
cur_k
+
k_frag_offsets
[
i
];
int
group_id
=
sh_g_idx_int_ptr
[
actual_k
];
int
rel_group_id
=
group_id
-
sh_first_group_id
;
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
i
][
0
])))
=
sh_s
[
rel_group_id
*
s_sh_stride
+
s_col_shift
];
}
};
auto
fetch_zp_to_registers
=
[
&
](
int
k
,
int
full_pipe
)
{
// This code does not handle group_blocks == 0,
// which signifies act_order.
// has_zp implies AWQ, which doesn't have act_order,
static_assert
(
!
has_zp
||
group_blocks
!=
0
);
if
constexpr
(
has_zp
&&
!
is_zp_float
)
{
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
group_blocks
==
-
1
)
{
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp
))[
zp_sh_rd
+
i
];
}
}
else
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
else
{
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
int
cur_k
=
warp_row
*
16
;
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
int
cur_group_id
=
0
;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
cur_group_id
=
k_blocks
/
group_blocks
;
#pragma nv_diagnostic pop
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pipe
;
sh_zp_stage
+=
cur_group_id
*
zp_sh_stride
;
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
}
else
if
constexpr
(
has_zp
&&
is_zp_float
)
{
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_zpf
[
k
%
2
])[
0
]
=
sh_zp_stage
[
zp_sh_rd
];
}
else
{
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
int
cur_k
=
warp_row
*
16
;
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
int
cur_group_id
=
k_blocks
/
group_blocks
;
#pragma nv_diagnostic pop
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pipe
;
reinterpret_cast
<
int4
*>
(
&
frag_zpf
[
k
%
2
])[
0
]
=
sh_zp_stage
[
zp_sh_rd
+
cur_group_id
*
zp_sh_stride
];
}
}
}
};
// Execute the actual tensor core matmul of a sub-tile.
auto
matmul
=
[
&
](
int
k
)
{
if
constexpr
(
has_zp
&&
!
is_zp_float
)
{
FragB
frag_zp_0
;
FragB
frag_zp_1
;
int
zp_quant_0
,
zp_quant_1
;
if
constexpr
(
w_type
.
size_bits
()
==
4
)
{
zp_quant_0
=
frag_qzp
[
k
%
2
][
0
];
zp_quant_1
=
zp_quant_0
>>
8
;
}
else
{
static_assert
(
w_type
.
size_bits
()
==
8
);
zp_quant_0
=
frag_qzp
[
k
%
2
][
0
];
zp_quant_1
=
frag_qzp
[
k
%
2
][
1
];
}
frag_zp_0
=
dequant
<
scalar_t
,
w_type_id
>
(
zp_quant_0
);
frag_zp_1
=
dequant
<
scalar_t
,
w_type_id
>
(
zp_quant_1
);
frag_zp
[
0
]
=
frag_zp_0
[
0
];
frag_zp
[
1
]
=
frag_zp_0
[
1
];
frag_zp
[
2
]
=
frag_zp_1
[
0
];
frag_zp
[
3
]
=
frag_zp_1
[
1
];
}
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
FragB
frag_b0
;
FragB
frag_b1
;
int
b_quant_0
,
b_quant_1
;
if
constexpr
(
w_type
.
size_bits
()
==
4
)
{
b_quant_0
=
frag_b_quant
[
k
%
2
][
0
][
j
];
b_quant_1
=
b_quant_0
>>
8
;
}
else
{
static_assert
(
w_type
.
size_bits
()
==
8
);
int
*
frag_b_quant_ptr
=
reinterpret_cast
<
int
*>
(
frag_b_quant
[
k
%
2
]);
b_quant_0
=
frag_b_quant_ptr
[
j
*
2
+
0
];
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
}
frag_b0
=
dequant
<
scalar_t
,
w_type_id
>
(
b_quant_0
);
frag_b1
=
dequant
<
scalar_t
,
w_type_id
>
(
b_quant_1
);
// Apply zero-point to frag_b0
if
constexpr
(
has_zp
&&
!
is_zp_float
)
{
sub_zp
<
scalar_t
>
(
frag_b0
,
frag_zp
[
j
],
0
);
}
else
if
constexpr
(
has_zp
&&
is_zp_float
&&
group_blocks
!=
-
1
)
{
sub_zp
<
scalar_t
>
(
frag_b0
,
frag_zpf
[
k
%
2
][
j
],
0
);
}
// Apply scale to frag_b0
if
constexpr
(
has_act_order
)
{
scale4
<
scalar_t
>
(
frag_b0
,
act_frag_s
[
k
%
2
][
0
][
j
],
act_frag_s
[
k
%
2
][
1
][
j
],
act_frag_s
[
k
%
2
][
2
][
j
],
act_frag_s
[
k
%
2
][
3
][
j
],
0
);
}
else
{
if
constexpr
(
group_blocks
!=
-
1
)
{
scale
<
scalar_t
>
(
frag_b0
,
frag_s
[
k
%
2
][
j
],
0
);
}
}
// Apply zero-point to frag_b1
if
constexpr
(
has_zp
&&
!
is_zp_float
)
{
sub_zp
<
scalar_t
>
(
frag_b1
,
frag_zp
[
j
],
1
);
}
else
if
constexpr
(
has_zp
&&
is_zp_float
&&
group_blocks
!=
-
1
)
{
sub_zp
<
scalar_t
>
(
frag_b1
,
frag_zpf
[
k
%
2
][
j
],
1
);
}
// Apply scale to frag_b1
if
constexpr
(
has_act_order
)
{
scale4
<
scalar_t
>
(
frag_b1
,
act_frag_s
[
k
%
2
][
0
][
j
],
act_frag_s
[
k
%
2
][
1
][
j
],
act_frag_s
[
k
%
2
][
2
][
j
],
act_frag_s
[
k
%
2
][
3
][
j
],
1
);
}
else
{
if
constexpr
(
group_blocks
!=
-
1
)
{
scale
<
scalar_t
>
(
frag_b1
,
frag_s
[
k
%
2
][
j
],
1
);
}
}
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
mma
<
scalar_t
>
(
frag_a
[
k
%
2
][
i
],
frag_b0
,
frag_c
[
i
][
j
][
0
]);
mma
<
scalar_t
>
(
frag_a
[
k
%
2
][
i
],
frag_b1
,
frag_c
[
i
][
j
][
1
]);
}
}
};
// Since we slice across the k dimension of a tile in order to increase the
// number of warps while keeping the n dimension of a tile reasonable, we have
// multiple warps that accumulate their partial sums of the same output
// location; which we have to reduce over in the end. We do in shared memory.
auto
thread_block_reduce
=
[
&
]()
{
constexpr
int
red_off
=
threads
/
b_sh_stride_threads
/
2
;
if
(
red_off
>=
1
)
{
auto
red_idx
=
threadIdx
.
x
/
b_sh_stride_threads
;
constexpr
int
red_sh_stride
=
b_sh_stride_threads
*
4
*
2
;
constexpr
int
red_sh_delta
=
b_sh_stride_threads
;
int
red_sh_rd
=
red_sh_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
(
threadIdx
.
x
%
b_sh_stride_threads
);
// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only
// once by warp 1 and read only once by warp 0.
#pragma unroll
for
(
int
m_block
=
0
;
m_block
<
thread_m_blocks
;
m_block
++
)
{
#pragma unroll
for
(
int
i
=
red_off
;
i
>
0
;
i
/=
2
)
{
if
(
i
<=
red_idx
&&
red_idx
<
2
*
i
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
*
2
;
j
++
)
{
int
red_sh_wr
=
red_sh_delta
*
j
+
(
red_sh_rd
-
red_sh_stride
*
i
);
if
(
i
<
red_off
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh_red
[
red_sh_delta
*
j
+
red_sh_rd
]);
float
*
c_wr
=
reinterpret_cast
<
float
*>
(
&
sh_red
[
red_sh_wr
]);
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
j
][
k
]
+=
c_rd
[
k
]
+
c_wr
[
k
];
}
sh_red
[
red_sh_wr
]
=
reinterpret_cast
<
int4
*>
(
&
frag_c
)[
4
*
2
*
m_block
+
j
];
}
}
__syncthreads
();
}
if
(
red_idx
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
4
*
2
;
i
++
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh_red
[
red_sh_delta
*
i
+
red_sh_rd
]);
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
i
][
j
]
+=
c_rd
[
j
];
}
}
__syncthreads
();
}
}
};
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
auto
global_reduce_fp16
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
constexpr
int
active_threads
=
32
*
thread_n_blocks
/
4
;
if
(
threadIdx
.
x
<
active_threads
)
{
int
c_gl_stride
=
prob_n
/
8
;
int
c_gl_wr_delta_o
=
8
*
c_gl_stride
;
int
c_gl_wr_delta_i
=
4
*
(
active_threads
/
32
);
int
c_gl_wr
=
c_gl_stride
*
((
threadIdx
.
x
%
32
)
/
4
)
+
4
*
(
threadIdx
.
x
/
32
)
+
threadIdx
.
x
%
4
;
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
constexpr
int
c_sh_wr_delta
=
active_threads
;
auto
c_sh_wr
=
threadIdx
.
x
;
int
row
=
(
threadIdx
.
x
%
32
)
/
4
;
if
(
!
first
)
{
// Interestingly, doing direct global accesses here really seems to mess up
// the compiler and lead to slowdowns, hence we also use async-copies even
// though these fetches are not actually asynchronous.
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
cp_async4_pred
(
&
sh_red
[
c_sh_wr
+
c_sh_wr_delta
*
i
],
&
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)],
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
);
}
cp_async_fence
();
cp_async_wait
<
0
>
();
}
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
;
i
++
)
{
if
(
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
)
{
if
(
!
first
)
{
int4
c_red
=
sh_red
[
c_sh_wr
+
i
*
c_sh_wr_delta
];
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)]
+=
Dtype
::
num2float
(
reinterpret_cast
<
scalar_t
*>
(
&
c_red
)[
j
]);
}
}
if
(
!
last
)
{
int4
c
;
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
scalar_t
*>
(
&
c
)[
j
]
=
Dtype
::
float2num
(
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)]);
}
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)]
=
c
;
}
}
}
}
};
// Globally reduce over threadblocks that compute the same column block.
// We use a tmp C buffer to reduce in full fp32 precision.
auto
global_reduce_fp32
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
constexpr
int
tb_m
=
thread_m_blocks
*
16
;
constexpr
int
tb_n
=
thread_n_blocks
*
16
;
constexpr
int
c_size
=
tb_m
*
tb_n
*
sizeof
(
float
)
/
16
;
constexpr
int
active_threads
=
32
*
thread_n_blocks
/
4
;
bool
is_th_active
=
threadIdx
.
x
<
active_threads
;
int
par_offset
=
c_size
*
n_tiles
*
par_id
;
int
slice_offset
=
c_size
*
slice_col
;
constexpr
int
num_floats
=
thread_m_blocks
*
4
*
2
*
4
;
constexpr
int
th_size
=
num_floats
*
sizeof
(
float
)
/
16
;
int
c_cur_offset
=
par_offset
+
slice_offset
;
if
(
!
is_th_active
)
{
return
;
}
if
(
!
first
)
{
float
*
frag_c_ptr
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
#pragma unroll
for
(
int
k
=
0
;
k
<
th_size
;
k
++
)
{
sh_red
[
threadIdx
.
x
]
=
C_tmp
[
c_cur_offset
+
active_threads
*
k
+
threadIdx
.
x
];
float
*
sh_c_ptr
=
reinterpret_cast
<
float
*>
(
&
sh_red
[
threadIdx
.
x
]);
#pragma unroll
for
(
int
f
=
0
;
f
<
4
;
f
++
)
{
frag_c_ptr
[
k
*
4
+
f
]
+=
sh_c_ptr
[
f
];
}
}
}
if
(
!
last
)
{
int4
*
frag_c_ptr
=
reinterpret_cast
<
int4
*>
(
&
frag_c
);
#pragma unroll
for
(
int
k
=
0
;
k
<
th_size
;
k
++
)
{
C_tmp
[
c_cur_offset
+
active_threads
*
k
+
threadIdx
.
x
]
=
frag_c_ptr
[
k
];
}
}
};
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
auto
write_result
=
[
&
]()
{
int
c_gl_stride
=
prob_n
/
8
;
constexpr
int
c_sh_stride
=
2
*
thread_n_blocks
+
1
;
int
c_gl_wr_delta
=
c_gl_stride
*
(
threads
/
(
2
*
thread_n_blocks
));
constexpr
int
c_sh_rd_delta
=
c_sh_stride
*
(
threads
/
(
2
*
thread_n_blocks
));
int
c_gl_wr
=
c_gl_stride
*
(
threadIdx
.
x
/
(
2
*
thread_n_blocks
))
+
(
threadIdx
.
x
%
(
2
*
thread_n_blocks
));
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
int
c_sh_wr
=
(
4
*
c_sh_stride
)
*
((
threadIdx
.
x
%
32
)
/
4
)
+
(
threadIdx
.
x
%
32
)
%
4
;
c_sh_wr
+=
32
*
(
threadIdx
.
x
/
32
);
int
c_sh_rd
=
c_sh_stride
*
(
threadIdx
.
x
/
(
2
*
thread_n_blocks
))
+
(
threadIdx
.
x
%
(
2
*
thread_n_blocks
));
int
c_gl_wr_end
=
c_gl_stride
*
prob_m
;
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto
write
=
[
&
](
int
idx
,
float
c0
,
float
c1
,
FragS
&
s
)
{
scalar_t2
res
=
Dtype
::
nums2num2
(
Dtype
::
float2num
(
c0
),
Dtype
::
float2num
(
c1
));
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
w_type
.
size_bits
()
==
4
)
{
res
=
__hmul2
(
res
,
s
[
0
]);
}
((
scalar_t2
*
)
sh_red
)[
idx
]
=
res
;
};
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
int
wr
=
c_sh_wr
+
8
*
j
;
write
(
wr
+
(
4
*
c_sh_stride
)
*
0
+
0
,
frag_c
[
i
][
j
][
0
][
0
],
frag_c
[
i
][
j
][
0
][
1
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
8
+
0
,
frag_c
[
i
][
j
][
0
][
2
],
frag_c
[
i
][
j
][
0
][
3
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
0
+
4
,
frag_c
[
i
][
j
][
1
][
0
],
frag_c
[
i
][
j
][
1
][
1
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
8
+
4
,
frag_c
[
i
][
j
][
1
][
2
],
frag_c
[
i
][
j
][
1
][
3
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
}
c_sh_wr
+=
16
*
(
4
*
c_sh_stride
);
}
}
__syncthreads
();
#pragma unroll
for
(
int
i
=
0
;
i
<
div_ceil
(
16
*
thread_m_blocks
,
threads
/
(
2
*
thread_n_blocks
));
i
++
)
{
if
(
c_gl_wr
<
c_gl_wr_end
)
{
if
(
use_atomic_add
&&
slice_count
>
1
)
{
scalar_t2
*
C_half2
=
reinterpret_cast
<
scalar_t2
*>
(
&
C
[
c_gl_wr
]);
scalar_t2
*
sh_red_half2
=
reinterpret_cast
<
scalar_t2
*>
(
&
sh_red
[
c_sh_rd
]);
#pragma unroll
for
(
int
a
=
0
;
a
<
4
;
a
++
)
{
atomicAdd
(
&
C_half2
[
a
],
sh_red_half2
[
a
]);
}
}
else
{
C
[
c_gl_wr
]
=
sh_red
[
c_sh_rd
];
}
c_gl_wr
+=
c_gl_wr_delta
;
c_sh_rd
+=
c_sh_rd_delta
;
}
}
};
// Start global fetch and register load pipelines.
auto
start_pipes
=
[
&
]()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
stages
-
1
;
i
++
)
{
if
(
has_act_order
&&
i
==
0
)
{
int
last_g_idx
=
slice_k_start
+
stages
*
tb_k
*
2
;
if
(
last_g_idx
>=
prob_k
)
{
last_g_idx
=
prob_k
-
1
;
}
fetch_scales_to_shared
(
true
,
g_idx
[
slice_k_start
],
g_idx
[
last_g_idx
]);
}
if
constexpr
(
has_zp
&&
!
is_zp_float
&&
group_blocks
==
-
1
)
{
if
(
i
==
0
)
{
fetch_zp_to_shared
();
}
}
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
);
}
zero_accums
();
wait_for_stage
();
init_same_group
(
0
);
fetch_to_registers
(
0
,
0
);
fetch_scales_to_registers
(
0
,
0
);
fetch_zp_to_registers
(
0
,
0
);
a_gl_rd
+=
a_gl_rd_delta_o
*
(
stages
-
1
);
slice_k_start_shared_fetch
+=
tb_k
*
(
stages
-
1
);
};
if
(
slice_iters
)
{
start_pipes
();
}
// Main loop.
while
(
slice_iters
)
{
// We unroll over both the global fetch and the register load pipeline to
// ensure all shared memory accesses are static. Note that both pipelines
// have even length meaning that the next iteration will always start at
// index 0.
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
stages
;)
{
#pragma unroll
for
(
int
k
=
0
;
k
<
b_sh_wr_iters
;
k
++
)
{
fetch_to_registers
(
k
+
1
,
pipe
%
stages
);
fetch_scales_to_registers
(
k
+
1
,
pipe
);
fetch_zp_to_registers
(
k
+
1
,
pipe
);
if
(
k
==
b_sh_wr_iters
-
2
)
{
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
slice_iters
>=
stages
);
pipe
++
;
wait_for_stage
();
init_same_group
(
pipe
%
stages
);
}
matmul
(
k
);
}
slice_iters
--
;
if
(
slice_iters
==
0
)
{
break
;
}
}
a_gl_rd
+=
a_gl_rd_delta_o
*
stages
;
slice_k_start
+=
tb_k
*
stages
;
slice_k_start_shared_fetch
+=
tb_k
*
stages
;
if
constexpr
(
has_act_order
)
{
int
first_group_id
=
g_idx
[
slice_k_start
];
int
last_g_idx
=
slice_k_start
+
stages
*
tb_k
*
2
;
if
(
last_g_idx
>=
prob_k
)
{
last_g_idx
=
prob_k
-
1
;
}
int
last_group_id
=
g_idx
[
last_g_idx
];
if
(
last_group_id
>=
sh_first_group_id
+
sh_num_groups
)
{
fetch_scales_to_shared
(
false
,
first_group_id
,
last_group_id
);
__syncthreads
();
}
}
// Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation.
if
(
slice_iters
==
0
)
{
cp_async_wait
<
0
>
();
bool
last
=
slice_idx
==
slice_count
-
1
;
// For per-column scales, we only fetch them here in the final step before
// write-out
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
)
{
if
constexpr
(
w_type
.
size_bits
()
==
8
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
cp_async_fence
();
}
else
{
if
(
last
||
use_atomic_add
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
cp_async_fence
();
}
}
}
thread_block_reduce
();
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
)
{
if
constexpr
(
w_type
.
size_bits
()
==
8
)
{
cp_async_wait
<
0
>
();
__syncthreads
();
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
0
]
=
sh_s
[
s_sh_rd
+
0
];
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
1
]
=
sh_s
[
s_sh_rd
+
4
];
}
}
else
{
if
(
last
||
use_atomic_add
)
{
cp_async_wait
<
0
>
();
__syncthreads
();
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
0
]
=
sh_s
[
s_sh_rd
+
0
];
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
1
]
=
sh_s
[
s_sh_rd
+
4
];
}
}
}
}
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
w_type
.
size_bits
()
==
8
)
{
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
0
][
0
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
0
][
2
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
1
][
0
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
1
][
2
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
}
}
}
}
if
(
slice_count
>
1
&&
!
use_atomic_add
)
{
// only globally reduce if there is more than one block in a slice
barrier_acquire
(
&
locks
[
slice_col
],
slice_idx
);
if
(
use_fp32_reduce
)
{
global_reduce_fp32
(
slice_idx
==
0
,
last
);
}
else
{
global_reduce_fp16
(
slice_idx
==
0
,
last
);
}
barrier_release
(
&
locks
[
slice_col
],
last
);
}
if
(
last
||
use_atomic_add
)
// only the last block in a slice actuallywrites the result
write_result
();
slice_row
=
0
;
slice_col_par
++
;
slice_col
++
;
init_slice
();
if
(
slice_iters
)
{
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
+=
b_sh_stride
-
b_gl_rd_delta_o
*
k_tiles
;
if
(
slice_col
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
-=
b_gl_stride
;
}
// Update slice k/n for scales loading
if
constexpr
(
has_act_order
)
{
slice_k_start
=
tb_k
*
slice_row
;
slice_k_finish
=
slice_k_start
+
tb_k
*
slice_iters
;
slice_k_start_shared_fetch
=
slice_k_start
;
slice_n_offset
=
act_s_col_tb_stride
*
slice_col
;
}
else
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
zp_gl_rd
=
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
start_pipes
();
}
}
}
}
#define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS, \
IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
if constexpr (!IS_ZP_FLOAT || std::is_same<scalar_t, half>::value) { \
cudaFuncSetAttribute( \
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, \
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
num_groups, prob_m, prob_n, prob_k, lda, locks, \
part_use_atomic_add, use_fp32_reduce); \
} \
}
typedef
struct
{
int
thread_k
;
int
thread_n
;
int
num_threads
;
}
thread_config_t
;
typedef
struct
{
int
max_m_blocks
;
thread_config_t
tb_cfg
;
}
exec_config_t
;
thread_config_t
small_batch_thread_configs
[]
=
{
// Ordered by priority
// thread_k, thread_n, num_threads
{
128
,
128
,
256
},
{
64
,
128
,
128
},
{
128
,
64
,
128
},
};
{
128
,
64
,
128
}};
thread_config_t
large_batch_thread_configs
[]
=
{
// Ordered by priority
...
...
@@ -1815,9 +138,12 @@ thread_config_t large_batch_thread_configs[] = {
// thread_k, thread_n, num_threads
{
64
,
256
,
256
},
{
64
,
128
,
128
},
{
128
,
64
,
128
}
,
{
128
,
64
,
128
}
};
};
typedef
struct
{
int
blocks_per_sm
;
thread_config_t
tb_cfg
;
}
exec_config_t
;
int
get_scales_cache_size
(
thread_config_t
const
&
th_config
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
...
...
@@ -1842,7 +168,6 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
tb_groups
*
pipe_stages
*
2
;
// Chunk size is 2x pipeline over dim K
load_groups
=
max
(
load_groups
,
32
);
// We load at least 32 scale groups
return
load_groups
*
tb_n
*
2
;
}
else
{
int
tb_scales
=
tb_groups
*
tb_n
*
2
;
...
...
@@ -1850,49 +175,43 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
}
}
bool
is_valid_cache_size
(
thread_config_t
const
&
th_config
,
int
max_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
scales_cache_size
,
int
max_shared_mem
)
{
int
get_kernel_cache_size
(
thread_config_t
const
&
th_config
,
int
thread_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
has_zp
,
int
is_zp_float
)
{
int
pack_factor
=
32
/
num_bits
;
// Get B size
int
tb_k
=
th_config
.
thread_k
;
int
tb_n
=
th_config
.
thread_n
;
int
b_size
=
(
tb_k
*
tb_n
/
pack_factor
)
*
4
;
// Get A size
int
m_blocks
=
div_ceil
(
prob_m
,
16
);
int
tb_max_m
=
16
;
while
(
true
)
{
if
(
m_blocks
>=
max_m_blocks
)
{
tb_max_m
*=
max_m_blocks
;
break
;
}
max_m_blocks
--
;
if
(
max_m_blocks
==
0
)
{
TORCH_CHECK
(
false
,
"Unexpected m_blocks = "
,
m_blocks
);
}
int
tb_m
=
thread_m_blocks
*
16
;
int
sh_a_size
=
pipe_stages
*
(
tb_m
*
tb_k
)
*
2
;
int
sh_b_size
=
pipe_stages
*
(
tb_k
*
tb_n
/
pack_factor
)
*
4
;
int
sh_red_size
=
tb_m
*
(
tb_n
+
8
);
int
sh_s_size
=
get_scales_cache_size
(
th_config
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
);
int
sh_g_idx_size
=
has_act_order
&&
!
is_k_full
?
pipe_stages
*
tb_k
/
4
:
0
;
int
sh_zp_size
=
0
;
if
(
has_zp
)
{
if
(
is_zp_float
)
sh_zp_size
=
sh_s_size
;
else
if
(
num_bits
==
4
)
sh_zp_size
=
sh_s_size
/
4
;
else
if
(
num_bits
==
8
)
sh_zp_size
=
sh_s_size
/
2
;
}
int
a_size
=
(
tb_max_m
*
tb_k
)
*
2
;
float
pipe_size
=
(
a_size
+
b_size
)
*
pipe_stages
;
float
reduce_size
=
max
(
th_config
.
num_threads
*
32
*
4
,
(
tb_n
/
64
)
*
32
*
(
tb_max_m
/
16
)
*
4
*
2
*
4
*
2
);
int
total_size
=
max
(
sh_b_size
,
sh_red_size
)
+
sh_a_size
+
sh_s_size
+
sh_zp_size
+
sh_g_idx_size
;
TORCH_CHECK
(
max_shared_mem
/
2
>
scales_cache_size
);
// Sanity
return
pipe_size
+
reduce_size
<
0.95
f
*
(
max_shared_mem
-
scales_cache_size
);
return
total_size
;
}
bool
is_valid_config
(
thread_config_t
const
&
th_config
,
int
max
_m_blocks
,
bool
is_valid_config
(
thread_config_t
const
&
th_config
,
int
thread
_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
max_shared_mem
)
{
int
has_zp
,
int
is_zp_float
,
int
max_shared_mem
)
{
// Sanity
if
(
th_config
.
thread_k
==
-
1
||
th_config
.
thread_n
==
-
1
||
th_config
.
num_threads
==
-
1
)
{
...
...
@@ -1914,163 +233,204 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
return
false
;
}
// Determine cache for scales
int
scales_cache_size
=
get_scales_cache_size
(
th_config
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
);
// Check that pipeline fits into cache
if
(
!
is_valid_cache_size
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
scales_cache_size
,
max_shared_mem
))
{
return
false
;
}
return
true
;
int
cache_size
=
get_kernel_cache_size
(
th_config
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
);
return
cache_size
<=
max_shared_mem
;
}
int
determine_reduce_max_m
(
int
prob_m
,
int
max_par
)
{
constexpr
int
tile_m_size
=
16
;
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
}
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
// this is the most common cases
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order case (group_blocks == 0)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF(W_TYPE) \
COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \
COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M234(W_TYPE, 4, 8, 128)
#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128)
// We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF(W_TYPE) \
FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M234(W_TYPE, 4, 8, 128)
// We currently have 4-bit models only with group_blocks == 4
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF(W_TYPE) \
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M234(W_TYPE, 4, 8, 128)
if
(
prob_m
<=
tile_m_size
)
{
return
tile_m_size
;
template
<
typename
scalar_t
>
MarlinFuncPtr
get_marlin_kernel
(
const
vllm
::
ScalarType
q_type
,
int
thread_m_blocks
,
int
thread_n_blocks
,
int
thread_k_blocks
,
bool
m_block_size_8
,
bool
has_act_order
,
bool
has_zp
,
int
group_blocks
,
int
num_threads
,
bool
is_zp_float
)
{
int
num_bits
=
q_type
.
size_bits
();
auto
kernel
=
MarlinDefault
;
if
(
false
)
{
}
}
else
if
(
prob_m
<=
tile_m_size
*
2
)
{
return
tile_m_size
*
2
;
COMMON_GET_IF
(
vllm
::
kU4
)
COMMON_GET_IF
(
vllm
::
kU4B8
)
COMMON_GET_IF
(
vllm
::
kU8B128
)
}
else
if
(
prob_m
<=
tile_m_size
*
3
)
{
return
tile_m_size
*
3
;
BIGGROUP_GET_IF
(
vllm
::
kFE4M3fn
)
}
else
if
(
prob_m
<=
tile_m_size
*
4
)
{
return
tile_m_size
*
4
;
ACT_GET_IF
(
vllm
::
kU4B8
)
ACT_GET_IF
(
vllm
::
kU8B128
)
}
else
{
int
cur_par
=
min
(
div_ceil
(
prob_m
,
tile_m_size
*
4
),
max_par
);
return
tile_m_size
*
4
*
cur_par
;
if
(
std
::
is_same
<
scalar_t
,
half
>::
value
)
{
if
(
false
)
{
}
FZP_GET_IF
(
vllm
::
kU4
)
}
return
kernel
;
}
exec_config_t
determine_thread_config
(
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
max_shared_mem
)
{
int
max_m_blocks
=
4
;
while
(
max_m_blocks
>
0
)
{
if
(
prob_m
<=
16
)
{
for
(
auto
th_config
:
small_batch_thread_configs
)
{
if
(
is_valid_config
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
max_shared_mem
))
{
return
exec_config_t
{
max_m_blocks
,
th_config
};
}
}
}
else
{
for
(
auto
th_config
:
large_batch_thread_configs
)
{
if
(
is_valid_config
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
max_shared_mem
))
{
return
exec_config_t
{
max_m_blocks
,
th_config
};
}
}
template
<
typename
scalar_t
>
exec_config_t
determine_exec_config
(
const
vllm
::
ScalarType
&
q_type
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
thread_m_blocks
,
bool
m_block_size_8
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_zp
,
bool
is_zp_float
,
int
max_shared_mem
,
int
sms
)
{
exec_config_t
exec_cfg
=
exec_config_t
{
1
,
thread_config_t
{
-
1
,
-
1
,
-
1
}};
thread_config_t
*
thread_configs
=
thread_m_blocks
>
1
?
large_batch_thread_configs
:
small_batch_thread_configs
;
int
thread_configs_size
=
thread_m_blocks
>
1
?
sizeof
(
large_batch_thread_configs
)
/
sizeof
(
thread_config_t
)
:
sizeof
(
small_batch_thread_configs
)
/
sizeof
(
thread_config_t
);
for
(
int
i
=
0
;
i
<
thread_configs_size
;
i
++
)
{
thread_config_t
th_config
=
thread_configs
[
i
];
if
(
!
is_valid_config
(
th_config
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
max_shared_mem
))
{
continue
;
}
max_m_blocks
--
;
// Process less M blocks per invocation to reduce cache
// usage
}
int
cache_size
=
get_kernel_cache_size
(
th_config
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
);
return
exec_config_t
{
0
,
{
-
1
,
-
1
,
-
1
}};
}
int
group_blocks
=
0
;
if
(
!
has_act_order
)
{
group_blocks
=
group_size
==
-
1
?
-
1
:
group_size
/
16
;
}
#define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
false)
#define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, false)
auto
kernel
=
get_marlin_kernel
<
scalar_t
>
(
q_type
,
thread_m_blocks
,
th_config
.
thread_n
/
16
,
th_config
.
thread_k
/
16
,
m_block_size_8
,
has_act_order
,
has_zp
,
group_blocks
,
th_config
.
num_threads
,
is_zp_float
);
// We currently have 4-bit models only with group_blocks == 4
#define HQQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
true) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
true) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
true) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, true)
if
(
kernel
==
MarlinDefault
)
continue
;
// int m_tiles = div_ceil(prob_m, thread_m_blocks * 16);
// int n_tiles = prob_n / th_config.thread_n;
// int k_tiles = prob_k / th_config.thread_k;
return
{
1
,
th_config
};
}
return
exec_cfg
;
}
template
<
typename
scalar_t
>
void
marlin_mm
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
C_tmp
,
void
*
s
,
...
...
@@ -2078,78 +438,24 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
int
prob_n
,
int
prob_k
,
int
lda
,
void
*
workspace
,
vllm
::
ScalarType
const
&
q_type
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_zp
,
int
num_groups
,
int
group_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
max_par
,
bool
use_atomic_add
,
bool
use_fp32_reduce
,
bool
is_zp_float
)
{
int
dev
,
cudaStream_t
stream
,
int
thread_k
_init
,
int
thread_n_init
,
int
sms
,
bool
use_atomic_add
,
bool
use_fp32_reduce
,
bool
is_zp_float
)
{
if
(
has_zp
)
{
TORCH_CHECK
(
q_type
==
vllm
::
kU4
||
q_type
==
vllm
::
kU8
,
"q_type must be u4 or u8 when has_zp = True. Got = "
,
q_type
.
str
());
}
else
{
TORCH_CHECK
(
q_type
==
vllm
::
kU4B8
||
q_type
==
vllm
::
kU8B128
,
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = "
,
q_type
.
str
());
TORCH_CHECK
(
q_type
==
vllm
::
kU4B8
||
q_type
==
vllm
::
kU8B128
||
q_type
==
vllm
::
kFE4M3fn
,
"q_type must be uint4b8, uint8b128 or float8_e4m3fn when "
"has_zp = False. Got = "
,
q_type
.
str
());
}
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
", "
,
prob_n
,
", "
,
prob_k
,
"]"
);
// TODO: remove alias when we start supporting other 8bit types
int
num_bits
=
q_type
.
size_bits
();
int
tot_m
=
prob_m
;
int
tot_m_blocks
=
div_ceil
(
tot_m
,
16
);
int
pad
=
16
*
tot_m_blocks
-
tot_m
;
if
(
sms
==
-
1
)
{
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
dev
);
}
int
max_shared_mem
=
0
;
cudaDeviceGetAttribute
(
&
max_shared_mem
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
dev
);
TORCH_CHECK
(
max_shared_mem
>
0
);
// Set thread config
exec_config_t
exec_cfg
;
if
(
thread_k
!=
-
1
&&
thread_n
!=
-
1
)
{
// User-defined config
exec_cfg
=
exec_config_t
{
4
,
thread_config_t
{
thread_k
,
thread_n
,
default_threads
}};
}
else
{
// Auto config
exec_cfg
=
determine_thread_config
(
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
max_shared_mem
);
}
TORCH_CHECK
(
exec_cfg
.
max_m_blocks
>
0
&&
is_valid_config
(
exec_cfg
.
tb_cfg
,
exec_cfg
.
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
max_shared_mem
),
"Invalid thread config: max_m_blocks = "
,
exec_cfg
.
max_m_blocks
,
", thread_k = "
,
exec_cfg
.
tb_cfg
.
thread_k
,
", thread_n = "
,
exec_cfg
.
tb_cfg
.
thread_n
,
", num_threads = "
,
exec_cfg
.
tb_cfg
.
num_threads
,
" for MKN = ["
,
prob_m
,
", "
,
prob_k
,
", "
,
prob_n
,
"] and num_bits = "
,
num_bits
,
", group_size = "
,
group_size
,
", has_act_order = "
,
has_act_order
,
", is_k_full = "
,
is_k_full
,
", max_shared_mem = "
,
max_shared_mem
);
int
num_threads
=
exec_cfg
.
tb_cfg
.
num_threads
;
thread_k
=
exec_cfg
.
tb_cfg
.
thread_k
;
thread_n
=
exec_cfg
.
tb_cfg
.
thread_n
;
int
thread_k_blocks
=
thread_k
/
16
;
int
thread_n_blocks
=
thread_n
/
16
;
int
blocks
=
sms
;
TORCH_CHECK
(
prob_n
%
thread_n
==
0
,
"prob_n = "
,
prob_n
,
" is not divisible by thread_n = "
,
thread_n
);
TORCH_CHECK
(
prob_k
%
thread_k
==
0
,
"prob_k = "
,
prob_k
,
" is not divisible by thread_k = "
,
thread_k
);
int
group_blocks
=
0
;
if
(
has_act_order
)
{
if
(
is_k_full
)
{
...
...
@@ -2161,7 +467,6 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
TORCH_CHECK
(
group_size
==
0
);
group_blocks
=
0
;
}
}
else
{
if
(
group_size
==
-
1
)
{
group_blocks
=
-
1
;
...
...
@@ -2172,6 +477,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
}
}
int
num_bits
=
q_type
.
size_bits
();
const
int4
*
A_ptr
=
(
const
int4
*
)
A
;
const
int4
*
B_ptr
=
(
const
int4
*
)
B
;
int4
*
C_ptr
=
(
int4
*
)
C
;
...
...
@@ -2186,106 +492,138 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
if
(
has_act_order
)
{
// Permute A columns
int
block_rows
=
div_ceil
(
prob_m
,
blocks
);
permute_cols_kernel
<<<
blocks
,
default_threads
,
0
,
stream
>>>
(
int
block_rows
=
div_ceil
(
prob_m
,
sms
);
// avoid ">>>" being formatted to "> > >"
// clang-format off
permute_cols_kernel
<<<
sms
,
default_threads
,
0
,
stream
>>>
(
A_ptr
,
perm_ptr
,
a_tmp_ptr
,
prob_m
,
prob_k
,
lda
,
block_rows
);
// clang-format on
A_ptr
=
a_tmp_ptr
;
lda
=
prob_k
;
}
// If we have a full K, then we can run the non-act-order version of Marlin
// (since the weight rows are reordered by increasing group ids, and by having
// a full K, we have full original groups)
if
(
is_k_full
)
{
has_act_order
=
false
;
// If we have a full K, then we can run the non-act-order version of Marlin
// (since the weight rows are reordered by increasing group ids, and by
// having a full K, we have full original groups)
if
(
is_k_full
)
has_act_order
=
false
;
}
// Main loop
for
(
int
i
=
0
;
i
<
tot_m_blocks
;
i
+=
exec_cfg
.
max_m_blocks
)
{
int
thread_m_blocks
=
tot_m_blocks
-
i
;
prob_m
=
tot_m
-
16
*
i
;
int
par
=
1
;
if
(
thread_m_blocks
>
exec_cfg
.
max_m_blocks
)
{
// Note that parallel > 1 currently only works for inputs without any
// padding
par
=
(
16
*
thread_m_blocks
-
pad
)
/
(
16
*
exec_cfg
.
max_m_blocks
);
if
(
par
>
max_par
)
par
=
max_par
;
prob_m
=
(
16
*
exec_cfg
.
max_m_blocks
)
*
par
;
i
+=
exec_cfg
.
max_m_blocks
*
(
par
-
1
);
thread_m_blocks
=
exec_cfg
.
max_m_blocks
;
int
max_shared_mem
=
0
;
cudaDeviceGetAttribute
(
&
max_shared_mem
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
dev
);
TORCH_CHECK
(
max_shared_mem
>
0
);
int
max_par
=
16
;
if
(
prob_n
<=
4096
)
max_par
=
16
*
8
;
int
max_shared_mem_new
=
max_shared_mem
;
int
rest_m
=
prob_m
;
int
max_thread_m_blocks
=
4
;
while
(
rest_m
)
{
int
par_count
=
rest_m
/
(
max_thread_m_blocks
*
16
);
if
(
par_count
>
max_par
)
par_count
=
max_par
;
int
prob_m_split
=
par_count
>
0
?
(
par_count
*
(
max_thread_m_blocks
*
16
))
:
rest_m
;
int
thread_k
=
thread_k_init
;
int
thread_n
=
thread_n_init
;
int
thread_m_blocks
=
min
(
div_ceil
(
prob_m_split
,
16
),
max_thread_m_blocks
);
int
m_block_size_8
=
prob_m_split
<=
8
;
// Set thread config
exec_config_t
exec_cfg
;
thread_config_t
thread_tfg
;
if
(
thread_k
!=
-
1
&&
thread_n
!=
-
1
)
{
thread_tfg
=
thread_config_t
{
thread_k
,
thread_n
,
default_threads
};
exec_cfg
=
exec_config_t
{
1
,
thread_tfg
};
TORCH_CHECK
(
prob_n
%
thread_n
==
0
,
"prob_n = "
,
prob_n
,
" is not divisible by thread_n = "
,
thread_n
);
TORCH_CHECK
(
prob_k
%
thread_k
==
0
,
"prob_k = "
,
prob_k
,
" is not divisible by thread_k = "
,
thread_k
);
}
else
{
// Auto config
exec_cfg
=
determine_exec_config
<
scalar_t
>
(
q_type
,
prob_m_split
,
prob_n
,
prob_k
,
thread_m_blocks
,
m_block_size_8
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
max_shared_mem
,
sms
);
thread_tfg
=
exec_cfg
.
tb_cfg
;
if
(
thread_tfg
.
thread_k
==
-
1
&&
max_thread_m_blocks
>
1
)
{
max_thread_m_blocks
--
;
continue
;
}
}
// atomic add reduce have better performance only when m * n is small
bool
part_use_atomic_add
=
use_atomic_add
&&
div_ceil
(
prob_m
,
64
)
*
prob_n
<=
2048
;
int
num_threads
=
thread_tfg
.
num_threads
;
thread_k
=
thread_tfg
.
thread_k
;
thread_n
=
thread_tfg
.
thread_n
;
int
blocks
=
sms
*
exec_cfg
.
blocks_per_sm
;
if
(
exec_cfg
.
blocks_per_sm
>
1
)
max_shared_mem_new
=
max_shared_mem
/
exec_cfg
.
blocks_per_sm
-
1024
;
if
(
false
)
{
}
GPTQ_CALL_IF
(
vllm
::
kU4B8
,
16
,
4
,
256
)
GPTQ_CALL_IF
(
vllm
::
kU4B8
,
8
,
8
,
256
)
GPTQ_CALL_IF
(
vllm
::
kU4B8
,
8
,
4
,
128
)
GPTQ_CALL_IF
(
vllm
::
kU4B8
,
4
,
8
,
128
)
GPTQ_CALL_IF
(
vllm
::
kU8B128
,
16
,
4
,
256
)
GPTQ_CALL_IF
(
vllm
::
kU8B128
,
8
,
8
,
256
)
GPTQ_CALL_IF
(
vllm
::
kU8B128
,
8
,
4
,
128
)
GPTQ_CALL_IF
(
vllm
::
kU8B128
,
4
,
8
,
128
)
AWQ_CALL_IF
(
vllm
::
kU4
,
16
,
4
,
256
)
AWQ_CALL_IF
(
vllm
::
kU4
,
8
,
8
,
256
)
AWQ_CALL_IF
(
vllm
::
kU4
,
8
,
4
,
128
)
AWQ_CALL_IF
(
vllm
::
kU4
,
4
,
8
,
128
)
AWQ_CALL_IF
(
vllm
::
kU8
,
16
,
4
,
256
)
AWQ_CALL_IF
(
vllm
::
kU8
,
8
,
8
,
256
)
AWQ_CALL_IF
(
vllm
::
kU8
,
8
,
4
,
128
)
AWQ_CALL_IF
(
vllm
::
kU8
,
4
,
8
,
128
)
HQQ_CALL_IF
(
vllm
::
kU4
,
16
,
4
,
256
)
HQQ_CALL_IF
(
vllm
::
kU4
,
8
,
8
,
256
)
HQQ_CALL_IF
(
vllm
::
kU4
,
8
,
4
,
128
)
HQQ_CALL_IF
(
vllm
::
kU4
,
4
,
8
,
128
)
else
{
int
thread_k_blocks
=
thread_k
/
16
;
int
thread_n_blocks
=
thread_n
/
16
;
TORCH_CHECK
(
is_valid_config
(
thread_tfg
,
thread_m_blocks
,
prob_m_split
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
max_shared_mem_new
),
"Invalid thread config: thread_m_blocks = "
,
thread_m_blocks
,
", thread_k = "
,
thread_tfg
.
thread_k
,
", thread_n = "
,
thread_tfg
.
thread_n
,
", num_threads = "
,
thread_tfg
.
num_threads
,
" for MKN = ["
,
prob_m
,
", "
,
prob_k
,
", "
,
prob_n
,
"] and num_bits = "
,
num_bits
,
", prob_m_split = "
,
prob_m_split
,
", group_size = "
,
group_size
,
", has_act_order = "
,
has_act_order
,
", is_k_full = "
,
is_k_full
,
", has_zp = "
,
has_zp
,
", is_zp_float = "
,
is_zp_float
,
", max_shared_mem_new = "
,
max_shared_mem_new
);
auto
kernel
=
get_marlin_kernel
<
scalar_t
>
(
q_type
,
thread_m_blocks
,
thread_n_blocks
,
thread_k_blocks
,
m_block_size_8
,
has_act_order
,
has_zp
,
group_blocks
,
num_threads
,
is_zp_float
);
if
(
kernel
==
MarlinDefault
)
{
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
,
prob_m
,
", "
,
prob_n
,
", "
,
prob_k
,
"]"
,
", has_act_order = "
,
has_act_order
,
", num_groups = "
,
num_groups
,
", group_size = "
,
group_size
,
", prob_m_split = "
,
prob_m_split
,
", thread_m_blocks = "
,
thread_m_blocks
,
", thread_n_blocks = "
,
thread_n_blocks
,
", thread_k_blocks = "
,
thread_k_blocks
,
", num_bits = "
,
num_bits
);
",
num_threads = "
,
num_threads
,
",
num_bits = "
,
num_bits
);
}
A_ptr
+=
16
*
thread_m_blocks
*
(
lda
/
8
)
*
par
;
C_ptr
+=
16
*
thread_m_blocks
*
(
prob_n
/
8
)
*
par
;
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
max_shared_mem_new
);
bool
part_use_atomic_add
=
use_atomic_add
&&
div_ceil
(
prob_m_split
,
64
)
*
prob_n
<=
2048
;
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel
<<<
blocks
,
num_threads
,
max_shared_mem_new
,
stream
>>>
(
A_ptr
,
B_ptr
,
C_ptr
,
C_tmp_ptr
,
s_ptr
,
zp_ptr
,
g_idx_ptr
,
num_groups
,
prob_m_split
,
prob_n
,
prob_k
,
lda
,
locks
,
part_use_atomic_add
,
use_fp32_reduce
,
max_shared_mem_new
);
// clang-format on
A_ptr
+=
prob_m_split
*
(
lda
/
8
);
C_ptr
+=
prob_m_split
*
(
prob_n
/
8
);
rest_m
-=
prob_m_split
;
}
}
}
// namespace marlin
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeId
const
&
b_q_type_id
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
,
bool
use_atomic_add
,
bool
use_fp32_reduce
,
bool
is_zp_float
)
{
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
std
::
optional
<
torch
::
Tensor
>
c_or_none
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
b_zeros_or_none
,
std
::
optional
<
torch
::
Tensor
>
const
&
g_idx_or_none
,
std
::
optional
<
torch
::
Tensor
>
const
&
perm_or_none
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeId
const
&
b_q_type_id
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
use_atomic_add
,
bool
use_fp32_reduce
,
bool
is_zp_float
)
{
vllm
::
ScalarType
const
b_q_type
=
vllm
::
ScalarType
::
from_id
(
b_q_type_id
);
if
(
has_zp
)
{
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4
||
b_q_type
==
vllm
::
kU8
,
"b_q_type must be u4 or u8 when has_zp = True. Got = "
,
b_q_type
.
str
());
}
else
{
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4B8
||
b_q_type
==
vllm
::
kU8B128
,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = "
,
b_q_type
.
str
());
}
if
(
has_zp
&&
is_zp_float
)
{
TORCH_CHECK
(
a
.
scalar_type
()
==
at
::
ScalarType
::
Half
,
"Computation type must be float16 (half) when using float zero "
"points."
);
}
int
pack_factor
=
32
/
b_q_type
.
size_bits
();
// Verify A
...
...
@@ -2295,15 +633,19 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
", size_k = "
,
size_k
);
// Verify B
TORCH_CHECK
(
size_k
%
marlin
::
tile_size
==
0
,
"size_k = "
,
size_k
,
" is not divisible by tile_size = "
,
marlin
::
tile_size
);
TORCH_CHECK
((
size_k
/
marlin
::
tile_size
)
==
b_q_weight
.
size
(
0
),
TORCH_CHECK
(
size_k
%
MARLIN_NAMESPACE_NAME
::
tile_size
==
0
,
"size_k = "
,
size_k
,
" is not divisible by tile_size = "
,
MARLIN_NAMESPACE_NAME
::
tile_size
);
TORCH_CHECK
((
size_k
/
MARLIN_NAMESPACE_NAME
::
tile_size
)
==
b_q_weight
.
size
(
0
),
"Shape mismatch: b_q_weight.size(0) = "
,
b_q_weight
.
size
(
0
),
", size_k = "
,
size_k
,
", tile_size = "
,
marlin
::
tile_size
);
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
%
marlin
::
tile_size
==
0
,
"b_q_weight.size(1) = "
,
b_q_weight
.
size
(
1
),
" is not divisible by tile_size = "
,
marlin
::
tile_size
);
int
actual_size_n
=
(
b_q_weight
.
size
(
1
)
/
marlin
::
tile_size
)
*
pack_factor
;
", size_k = "
,
size_k
,
", tile_size = "
,
MARLIN_NAMESPACE_NAME
::
tile_size
);
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
%
MARLIN_NAMESPACE_NAME
::
tile_size
==
0
,
"b_q_weight.size(1) = "
,
b_q_weight
.
size
(
1
),
" is not divisible by tile_size = "
,
MARLIN_NAMESPACE_NAME
::
tile_size
);
int
actual_size_n
=
(
b_q_weight
.
size
(
1
)
/
MARLIN_NAMESPACE_NAME
::
tile_size
)
*
pack_factor
;
TORCH_CHECK
(
size_n
==
actual_size_n
,
"size_n = "
,
size_n
,
", actual_size_n = "
,
actual_size_n
);
...
...
@@ -2320,63 +662,47 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
TORCH_CHECK
(
b_scales
.
device
().
is_cuda
(),
"b_scales is not on GPU"
);
TORCH_CHECK
(
b_scales
.
is_contiguous
(),
"b_scales is not contiguous"
);
TORCH_CHECK
(
b_zeros
.
device
().
is_cuda
(),
"b_zeros is not on GPU"
);
TORCH_CHECK
(
b_zeros
.
is_contiguous
(),
"b_zeros is not contiguous"
);
TORCH_CHECK
(
g_idx
.
device
().
is_cuda
(),
"g_idx is not on GPU"
);
TORCH_CHECK
(
g_idx
.
is_contiguous
(),
"g_idx is not contiguous"
);
TORCH_CHECK
(
perm
.
device
().
is_cuda
(),
"perm is not on GPU"
);
TORCH_CHECK
(
perm
.
is_contiguous
(),
"perm is not contiguous"
);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int
thread_k
=
-
1
;
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int
thread_n
=
-
1
;
// sms: number of SMs to use for the kernel
int
sms
=
-
1
;
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
a
.
get_device
());
// Alloc buffers
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
a
));
auto
options
=
torch
::
TensorOptions
().
dtype
(
a
.
dtype
()).
device
(
a
.
device
());
torch
::
Tensor
c
;
if
(
use_atomic_add
)
{
c
=
torch
::
zeros
({
size_m
,
size_n
},
options
);
if
(
c_or_none
.
has_value
())
{
c
=
c_or_none
.
value
();
TORCH_CHECK
(
c
.
device
().
is_cuda
(),
"c is not on GPU"
);
TORCH_CHECK
(
c
.
is_contiguous
(),
"c is not contiguous"
);
TORCH_CHECK
(
c
.
size
(
0
)
==
size_m
,
"Shape mismatch: c.size(0) = "
,
c
.
size
(
0
),
", size_m = "
,
size_m
);
TORCH_CHECK
(
c
.
size
(
1
)
==
size_n
,
"Shape mismatch: c.size(1) = "
,
c
.
size
(
1
),
", size_n = "
,
size_n
);
}
else
{
c
=
torch
::
empty
({
size_m
,
size_n
},
options
);
}
torch
::
Tensor
a_tmp
;
bool
has_act_order
=
g_idx
.
size
(
0
)
!=
0
;
if
(
has_act_order
)
{
a_tmp
=
torch
::
empty
({
size_m
,
size_k
},
options
);
}
else
{
a_tmp
=
torch
::
empty
({
0
},
options
);
}
if
(
size_m
==
0
)
return
c
;
// Alloc C tmp buffer that is going to be used for the global reduce
torch
::
Tensor
c_tmp
;
int
reduce_max_m
=
marlin
::
determine_reduce_max_m
(
size_m
,
marlin
::
max_par
);
int
reduce_n
=
size_n
;
auto
options_fp32
=
torch
::
TensorOptions
().
dtype
(
at
::
kFloat
).
device
(
a
.
device
());
if
(
use_fp32_reduce
)
{
c_tmp
=
torch
::
empty
({
reduce_max_m
,
reduce_n
},
options_fp32
);
int
max_m_block_size
=
(
size_m
+
16
-
1
)
/
16
*
16
;
max_m_block_size
=
min
(
max_m_block_size
,
64
);
int
max_c_tmp_size
=
sms
*
max_m_block_size
*
MARLIN_NAMESPACE_NAME
::
max_thread_n
;
c_tmp
=
torch
::
empty
({
max_c_tmp_size
},
options_fp32
);
}
else
{
reduce_max_m
=
0
;
reduce_n
=
0
;
c_tmp
=
torch
::
empty
({
0
},
options_fp32
);
}
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int
thread_k
=
-
1
;
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int
thread_n
=
-
1
;
// sms: number of SMs to use for the kernel (can usually be left as auto -1)
int
sms
=
-
1
;
// Verify g_idx and perm
TORCH_CHECK
((
g_idx
.
size
(
0
)
==
0
&&
perm
.
size
(
0
)
==
0
)
||
(
g_idx
.
size
(
0
)
==
size_k
&&
perm
.
size
(
0
)
==
size_k
),
"Unexpected g_idx.size(0) = "
,
g_idx
.
size
(
0
),
" and perm.size(0) = "
,
perm
.
size
(
0
),
", where size_k = "
,
size_k
);
// Detect groupsize and act_order
int
num_groups
=
-
1
;
int
group_size
=
-
1
;
...
...
@@ -2387,7 +713,31 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
" is not size_n = "
,
size_n
);
num_groups
=
b_scales
.
size
(
0
);
torch
::
Tensor
g_idx
,
perm
,
a_tmp
;
if
(
g_idx_or_none
.
has_value
()
&&
perm_or_none
.
has_value
())
{
g_idx
=
g_idx_or_none
.
value
();
perm
=
perm_or_none
.
value
();
TORCH_CHECK
(
g_idx
.
device
().
is_cuda
(),
"g_idx is not on GPU"
);
TORCH_CHECK
(
g_idx
.
is_contiguous
(),
"g_idx is not contiguous"
);
TORCH_CHECK
(
perm
.
device
().
is_cuda
(),
"perm is not on GPU"
);
TORCH_CHECK
(
perm
.
is_contiguous
(),
"perm is not contiguous"
);
// Verify g_idx and perm
TORCH_CHECK
((
g_idx
.
size
(
-
1
)
==
0
&&
perm
.
size
(
-
1
)
==
0
)
||
(
g_idx
.
size
(
-
1
)
==
size_k
&&
perm
.
size
(
-
1
)
==
size_k
),
"Unexpected g_idx.size(-1) = "
,
g_idx
.
size
(
-
1
),
" and perm.size(-1) = "
,
perm
.
size
(
-
1
),
", where size_k = "
,
size_k
);
}
else
{
g_idx
=
torch
::
empty
({
0
},
options
);
perm
=
torch
::
empty
({
0
},
options
);
a_tmp
=
torch
::
empty
({
0
},
options
);
}
bool
has_act_order
=
g_idx
.
size
(
-
1
)
>
0
&&
perm
.
size
(
-
1
)
>
0
;
if
(
has_act_order
)
{
a_tmp
=
torch
::
empty
({
size_m
,
size_k
},
options
);
if
(
is_k_full
)
{
TORCH_CHECK
(
num_groups
>
1
,
"For act_order, num_groups must be > 1"
);
TORCH_CHECK
(
size_k
%
num_groups
==
0
,
"size_k = "
,
size_k
,
...
...
@@ -2398,6 +748,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
}
else
{
a_tmp
=
torch
::
empty
({
0
},
options
);
if
(
num_groups
>
1
)
{
TORCH_CHECK
(
size_k
%
num_groups
==
0
,
"size_k = "
,
size_k
,
...
...
@@ -2408,6 +759,33 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
}
torch
::
Tensor
b_zeros
;
if
(
b_zeros_or_none
.
has_value
())
{
b_zeros
=
b_zeros_or_none
.
value
();
TORCH_CHECK
(
b_zeros
.
device
().
is_cuda
(),
"b_zeros is not on GPU"
);
TORCH_CHECK
(
b_zeros
.
is_contiguous
(),
"b_zeros is not contiguous"
);
}
else
{
b_zeros
=
torch
::
empty
({
0
},
options
);
}
bool
has_zp
=
b_zeros
.
size
(
-
1
)
>
0
;
if
(
has_zp
)
{
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4
||
b_q_type
==
vllm
::
kU8
,
"b_q_type must be u4 or u8 when has_zp = True. Got = "
,
b_q_type
.
str
());
}
else
{
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4B8
||
b_q_type
==
vllm
::
kU8B128
||
b_q_type
==
vllm
::
kFE4M3fn
,
"b_q_type must be uint4b8, uint8b128 or float8_e4m3fn when "
"has_zp = False. Got = "
,
b_q_type
.
str
());
}
if
(
has_zp
&&
is_zp_float
)
{
TORCH_CHECK
(
a
.
scalar_type
()
==
at
::
ScalarType
::
Half
,
"Computation type must be float16 (half) when using float zero "
"points."
);
}
// Verify b_zeros
if
(
has_zp
)
{
int
rank
=
b_zeros
.
sizes
().
size
();
...
...
@@ -2431,9 +809,11 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
// Verify workspace size
TORCH_CHECK
(
size_n
%
marlin
::
min_thread_n
==
0
,
"size_n = "
,
size_n
,
", is not divisible by min_thread_n = "
,
marlin
::
min_thread_n
);
int
min_workspace_size
=
(
size_n
/
marlin
::
min_thread_n
)
*
marlin
::
max_par
;
TORCH_CHECK
(
size_n
%
MARLIN_NAMESPACE_NAME
::
min_thread_n
==
0
,
"size_n = "
,
size_n
,
", is not divisible by min_thread_n = "
,
MARLIN_NAMESPACE_NAME
::
min_thread_n
);
int
min_workspace_size
=
sms
;
TORCH_CHECK
(
workspace
.
numel
()
>=
min_workspace_size
,
"workspace.numel = "
,
workspace
.
numel
(),
" is below min_workspace_size = "
,
min_workspace_size
);
...
...
@@ -2447,8 +827,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
a_tmp
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
a
.
stride
(
0
),
workspace
.
data_ptr
(),
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
);
thread_k
,
thread_n
,
sms
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
);
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
marlin
::
marlin_mm
<
nv_bfloat16
>
(
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
...
...
@@ -2458,7 +837,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
a
.
stride
(
0
),
workspace
.
data_ptr
(),
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
);
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
);
}
else
{
TORCH_CHECK
(
false
,
"gpt_marlin_gemm only supports bfloat16 and float16"
);
}
...
...
csrc/quantization/gptq_marlin/kernel.h
0 → 100644
View file @
1d0c9d6b
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, \
int prob_k, int lda, int *locks, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem
namespace
MARLIN_NAMESPACE_NAME
{
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
const
vllm
::
ScalarTypeId
w_type_id
,
// weight ScalarType id
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
bool
m_block_size_8
,
// whether m_block_size == 8
// only works when thread_m_blocks == 1
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
>
__global__
void
Marlin
(
MARLIN_KERNEL_PARAMS
);
}
csrc/quantization/gptq_marlin/marlin_template.h
0 → 100644
View file @
1d0c9d6b
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Adapted from https://github.com/IST-DASLab/marlin
*/
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "dequant.h"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
std::is_same<scalar_t, nv_bfloat16>::value, \
"only float16 and bfloat16 is supported");
namespace
MARLIN_NAMESPACE_NAME
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
const
vllm
::
ScalarTypeId
w_type_id
,
// weight ScalarType id
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
bool
m_block_size_8
,
// whether m_block_size == 8
// only works when thread_m_blocks == 1
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
>
__global__
void
Marlin
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C_tmp
,
// fp32 tmp output buffer (for reduce)
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_fp32_reduce
// whether to use fp32 global reduce
)
{}
}
// namespace marlin
#else
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
template
<
typename
scalar_t
>
__device__
inline
void
mma
(
const
typename
ScalarType
<
scalar_t
>::
FragA
&
a_frag
,
const
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
typename
ScalarType
<
scalar_t
>::
FragC
&
frag_c
)
{
const
uint32_t
*
a
=
reinterpret_cast
<
const
uint32_t
*>
(
&
a_frag
);
const
uint32_t
*
b
=
reinterpret_cast
<
const
uint32_t
*>
(
&
frag_b
);
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
if
constexpr
(
std
::
is_same
<
scalar_t
,
half
>::
value
)
{
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
)
{
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
{
STATIC_ASSERT_SCALAR_TYPE_VALID
(
scalar_t
);
}
}
template
<
typename
scalar_t
>
__device__
inline
void
mma_trans
(
const
typename
ScalarType
<
scalar_t
>::
FragA
&
a_frag
,
const
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
const
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b2
,
typename
ScalarType
<
scalar_t
>::
FragC
&
frag_c
)
{
const
uint32_t
*
a
=
reinterpret_cast
<
const
uint32_t
*>
(
&
a_frag
);
const
uint32_t
*
b
=
reinterpret_cast
<
const
uint32_t
*>
(
&
frag_b
);
const
uint32_t
*
b2
=
reinterpret_cast
<
const
uint32_t
*>
(
&
frag_b2
);
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
if
constexpr
(
std
::
is_same
<
scalar_t
,
half
>::
value
)
{
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
b
[
0
]),
"r"
(
b2
[
0
]),
"r"
(
b
[
1
]),
"r"
(
b2
[
1
]),
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
)
{
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
b
[
0
]),
"r"
(
b2
[
0
]),
"r"
(
b
[
1
]),
"r"
(
b2
[
1
]),
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
{
STATIC_ASSERT_SCALAR_TYPE_VALID
(
scalar_t
);
}
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
template
<
int
count
,
typename
scalar_t
>
__device__
inline
void
ldsm
(
typename
ScalarType
<
scalar_t
>::
FragA
&
frag_a
,
const
void
*
smem_ptr
)
{
uint32_t
*
a
=
reinterpret_cast
<
uint32_t
*>
(
&
frag_a
);
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
if
constexpr
(
count
==
4
)
{
asm
volatile
(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];
\n
"
:
"=r"
(
a
[
0
]),
"=r"
(
a
[
1
]),
"=r"
(
a
[
2
]),
"=r"
(
a
[
3
])
:
"r"
(
smem
));
}
else
if
constexpr
(
count
==
2
)
{
asm
volatile
(
"ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];
\n
"
:
"=r"
(
a
[
0
]),
"=r"
(
a
[
1
])
:
"r"
(
smem
));
}
else
if
constexpr
(
count
==
1
)
{
asm
volatile
(
"ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];
\n
"
:
"=r"
(
a
[
0
])
:
"r"
(
smem
));
}
else
{
static_assert
(
count
==
1
||
count
==
2
||
count
==
4
,
"invalid count"
);
}
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
template
<
typename
scalar_t
>
__device__
inline
void
scale
(
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
typename
ScalarType
<
scalar_t
>::
FragS
&
frag_s
,
int
i
)
{
using
scalar_t2
=
typename
ScalarType
<
scalar_t
>::
scalar_t2
;
scalar_t2
s
=
ScalarType
<
scalar_t
>::
num2num2
(
reinterpret_cast
<
scalar_t
*>
(
&
frag_s
)[
i
]);
frag_b
[
0
]
=
__hmul2
(
frag_b
[
0
],
s
);
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s
);
}
template
<
typename
scalar_t
>
__device__
inline
void
scale_and_sub
(
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
scalar_t
s
,
scalar_t
zp
)
{
using
scalar_t2
=
typename
ScalarType
<
scalar_t
>::
scalar_t2
;
scalar_t2
s2
=
ScalarType
<
scalar_t
>::
num2num2
(
s
);
scalar_t2
zp2
=
ScalarType
<
scalar_t
>::
num2num2
(
zp
);
frag_b
[
0
]
=
__hfma2
(
frag_b
[
0
],
s2
,
__hneg2
(
zp2
));
frag_b
[
1
]
=
__hfma2
(
frag_b
[
1
],
s2
,
__hneg2
(
zp2
));
}
template
<
typename
scalar_t
>
__device__
inline
void
sub_zp
(
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
typename
ScalarType
<
scalar_t
>::
scalar_t2
&
frag_zp
,
int
i
)
{
using
scalar_t2
=
typename
ScalarType
<
scalar_t
>::
scalar_t2
;
scalar_t2
zp
=
ScalarType
<
scalar_t
>::
num2num2
(
reinterpret_cast
<
scalar_t
*>
(
&
frag_zp
)[
i
]);
frag_b
[
0
]
=
__hsub2
(
frag_b
[
0
],
zp
);
frag_b
[
1
]
=
__hsub2
(
frag_b
[
1
],
zp
);
}
// Same as above, but for act_order (each K is multiplied individually)
template
<
typename
scalar_t
>
__device__
inline
void
scale4
(
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
typename
ScalarType
<
scalar_t
>::
FragS
&
frag_s_1
,
typename
ScalarType
<
scalar_t
>::
FragS
&
frag_s_2
,
typename
ScalarType
<
scalar_t
>::
FragS
&
frag_s_3
,
typename
ScalarType
<
scalar_t
>::
FragS
&
frag_s_4
,
int
i
)
{
using
scalar_t2
=
typename
ScalarType
<
scalar_t
>::
scalar_t2
;
scalar_t2
s_val_1_2
;
s_val_1_2
.
x
=
reinterpret_cast
<
scalar_t
*>
(
&
frag_s_1
)[
i
];
s_val_1_2
.
y
=
reinterpret_cast
<
scalar_t
*>
(
&
frag_s_2
)[
i
];
scalar_t2
s_val_3_4
;
s_val_3_4
.
x
=
reinterpret_cast
<
scalar_t
*>
(
&
frag_s_3
)[
i
];
s_val_3_4
.
y
=
reinterpret_cast
<
scalar_t
*>
(
&
frag_s_4
)[
i
];
frag_b
[
0
]
=
__hmul2
(
frag_b
[
0
],
s_val_1_2
);
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s_val_3_4
);
}
// Given 2 floats multiply by 2 scales (halves)
template
<
typename
scalar_t
>
__device__
inline
void
scale_float
(
float
*
c
,
typename
ScalarType
<
scalar_t
>::
FragS
&
s
)
{
scalar_t
*
s_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
s
);
c
[
0
]
=
__fmul_rn
(
c
[
0
],
ScalarType
<
scalar_t
>::
num2float
(
s_ptr
[
0
]));
c
[
1
]
=
__fmul_rn
(
c
[
1
],
ScalarType
<
scalar_t
>::
num2float
(
s_ptr
[
1
]));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__
inline
void
barrier_acquire
(
int
*
lock
,
int
count
)
{
if
(
threadIdx
.
x
==
0
)
{
int
state
=
-
1
;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm
volatile
(
"ld.global.acquire.gpu.b32 %0, [%1];
\n
"
:
"=r"
(
state
)
:
"l"
(
lock
));
while
(
state
!=
count
);
}
__syncthreads
();
}
// Release barrier and increment visitation count.
__device__
inline
void
barrier_release
(
int
*
lock
,
bool
reset
=
false
)
{
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
if
(
reset
)
{
lock
[
0
]
=
0
;
return
;
}
int
val
=
1
;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm
volatile
(
"fence.acq_rel.gpu;
\n
"
);
asm
volatile
(
"red.relaxed.gpu.global.add.s32 [%0], %1;
\n
"
:
:
"l"
(
lock
),
"r"
(
val
));
}
}
// Wait until value of lock to be negative, and then add 1
__device__
inline
void
wait_negative_and_add
(
int
*
lock
)
{
if
(
threadIdx
.
x
==
0
)
{
int
state
=
0
;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm
volatile
(
"ld.global.acquire.gpu.b32 %0, [%1];
\n
"
:
"=r"
(
state
)
:
"l"
(
lock
));
while
(
state
>=
0
);
atomicAdd
(
lock
,
1
);
}
__syncthreads
();
}
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
const
vllm
::
ScalarTypeId
w_type_id
,
// weight ScalarType id
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
bool
m_block_size_8
,
// whether m_block_size == 8
// only works when thread_m_blocks == 1
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
>
__global__
void
Marlin
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C_tmp
,
// fp32 tmp output buffer (for reduce)
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
lda
,
// A.stride(0), equal to prob_k is A is contiguous
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_atomic_add
,
// whether to use atomic add to reduce
bool
use_fp32_reduce
,
// whether to use fp32 global reduce
int
max_shared_mem
)
{
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
// example:
// 0 1 3
// 0 2 3
// 1 2 4
// While this kind of partitioning makes things somewhat more complicated, it
// ensures good utilization of all SMs for many kinds of shape and GPU
// configurations, while requiring as few slow global cross-threadblock
// reductions as possible.
using
Dtype
=
ScalarType
<
scalar_t
>
;
using
scalar_t2
=
typename
ScalarType
<
scalar_t
>::
scalar_t2
;
using
FragA
=
typename
ScalarType
<
scalar_t
>::
FragA
;
using
FragB
=
typename
ScalarType
<
scalar_t
>::
FragB
;
using
FragC
=
typename
ScalarType
<
scalar_t
>::
FragC
;
using
FragS
=
typename
ScalarType
<
scalar_t
>::
FragS
;
using
FragZP
=
typename
ScalarType
<
scalar_t
>::
FragZP
;
static
constexpr
auto
w_type
=
vllm
::
ScalarType
::
from_id
(
w_type_id
);
constexpr
bool
has_zp
=
w_type
==
vllm
::
kU4
||
w_type
==
vllm
::
kU8
;
constexpr
bool
has_act_order
=
group_blocks
==
0
;
constexpr
int
m_block_size
=
m_block_size_8
?
8
:
(
16
*
thread_m_blocks
);
constexpr
int
pack_factor
=
32
/
w_type
.
size_bits
();
static_assert
(
thread_m_blocks
==
1
||
!
m_block_size_8
);
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
int
parallel
=
1
;
if
(
prob_m
>
m_block_size
)
{
parallel
=
prob_m
/
m_block_size
;
prob_m
=
m_block_size
;
}
int
k_tiles
=
prob_k
/
16
/
thread_k_blocks
;
int
n_tiles
=
prob_n
/
16
/
thread_n_blocks
;
int
iters
=
div_ceil
(
k_tiles
*
n_tiles
*
parallel
,
gridDim
.
x
);
if
constexpr
(
!
has_act_order
&&
group_blocks
!=
-
1
)
{
if
(
group_blocks
>=
thread_k_blocks
)
{
// Ensure that the number of tiles in each stripe is a multiple of the
// groupsize; this avoids an annoying special case where a stripe starts
// in the middle of group.
iters
=
(
group_blocks
/
thread_k_blocks
)
*
div_ceil
(
iters
,
(
group_blocks
/
thread_k_blocks
));
}
}
int
slice_row
=
(
iters
*
blockIdx
.
x
)
%
k_tiles
;
int
slice_col_par
=
(
iters
*
blockIdx
.
x
)
/
k_tiles
;
int
slice_col
=
slice_col_par
;
int
slice_iters
;
// number of threadblock tiles in the current slice
int
slice_count
=
0
;
// total number of active threadblocks in the current slice
int
slice_idx
;
// index of threadblock in current slice; numbered bottom to
// top
int
par_id
=
0
;
int
locks_off
=
0
;
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if
(
slice_col_par
>=
n_tiles
)
{
A
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
lda
/
8
;
C
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_n
/
8
;
slice_col
=
slice_col_par
%
n_tiles
;
par_id
=
slice_col_par
/
n_tiles
;
}
if
(
parallel
*
n_tiles
>=
gridDim
.
x
)
{
// when parallel * n_tiles >= sms
// then there are at most $sms$ conflict tile blocks
locks_off
=
blockIdx
.
x
;
}
else
{
locks_off
=
(
iters
*
blockIdx
.
x
)
/
k_tiles
-
1
;
}
// Compute all information about the current slice which is required for
// synchronization.
auto
init_slice
=
[
&
](
bool
first_init
=
false
)
{
slice_iters
=
iters
*
(
blockIdx
.
x
+
1
)
-
(
k_tiles
*
slice_col_par
+
slice_row
);
if
(
slice_iters
<
0
||
slice_col_par
>=
n_tiles
*
parallel
)
slice_iters
=
0
;
if
(
slice_iters
==
0
)
return
;
if
(
slice_row
+
slice_iters
>
k_tiles
)
slice_iters
=
k_tiles
-
slice_row
;
slice_count
=
1
;
slice_idx
=
0
;
int
col_first
=
iters
*
div_ceil
(
k_tiles
*
slice_col_par
,
iters
);
if
(
col_first
<=
k_tiles
*
(
slice_col_par
+
1
))
{
int
col_off
=
col_first
-
k_tiles
*
slice_col_par
;
slice_count
=
div_ceil
(
k_tiles
-
col_off
,
iters
);
if
(
col_off
>
0
)
slice_count
++
;
int
delta_first
=
iters
*
blockIdx
.
x
-
col_first
;
if
(
delta_first
<
0
||
(
col_off
==
0
&&
delta_first
==
0
))
slice_idx
=
slice_count
-
1
;
else
{
slice_idx
=
slice_count
-
1
-
delta_first
/
iters
;
if
(
col_off
>
0
)
slice_idx
--
;
}
}
if
(
parallel
*
n_tiles
>=
gridDim
.
x
)
{
if
(
slice_count
>
1
&&
slice_idx
==
slice_count
-
1
)
{
locks_off
++
;
}
}
else
{
locks_off
++
;
}
if
(
first_init
&&
use_atomic_add
&&
slice_count
>
1
&&
slice_idx
==
0
)
{
constexpr
int
threads_per_m
=
16
*
thread_n_blocks
/
8
;
int
m_per_thread
=
div_ceil
(
thread_m_blocks
*
16
,
threads
/
threads_per_m
);
if
(
m_block_size_8
)
m_per_thread
=
div_ceil
(
8
,
threads
/
threads_per_m
);
for
(
int
i
=
0
;
i
<
m_per_thread
;
i
++
)
{
int
row
=
threads
/
threads_per_m
*
i
+
threadIdx
.
x
/
threads_per_m
;
if
(
row
<
prob_m
)
{
int
col
=
slice_col
*
16
*
thread_n_blocks
/
8
+
threadIdx
.
x
%
threads_per_m
;
C
[
row
*
prob_n
/
8
+
col
]
=
{
0
,
0
,
0
,
0
};
}
}
// After write zero to output, write a negative value to lock.
// Every SM that processes the same slice would wait for
// the negative value, and then atomicAdd 1 to it.
// After all SMs are processed, the lock value would back to 0 again.
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
locks
[
locks_off
]
=
1
-
slice_count
;
}
if
(
slice_col
==
n_tiles
)
{
A
+=
16
*
thread_m_blocks
*
lda
/
8
;
C
+=
16
*
thread_m_blocks
*
prob_n
/
8
;
slice_col
=
0
;
par_id
++
;
}
};
init_slice
(
true
);
// A sizes/strides
// stride of the A matrix in global memory
int
a_gl_stride
=
lda
/
8
;
// stride of an A matrix tile in shared memory
constexpr
int
a_sh_stride
=
16
*
thread_k_blocks
/
8
;
// delta between subsequent A tiles in global memory
constexpr
int
a_gl_rd_delta_o
=
16
*
thread_k_blocks
/
8
;
// between subsequent accesses within a tile
int
a_gl_rd_delta_i
=
a_gl_stride
*
(
threads
/
a_gl_rd_delta_o
);
// between shared memory writes
constexpr
int
a_sh_wr_delta
=
a_sh_stride
*
(
threads
/
a_gl_rd_delta_o
);
// between shared memory tile reads
constexpr
int
a_sh_rd_delta_o
=
2
*
((
threads
/
32
)
/
(
thread_n_blocks
/
4
));
// within a shared memory tile
constexpr
int
a_sh_rd_delta_i
=
a_sh_stride
*
16
;
// overall size of a tile
constexpr
int
a_sh_stage
=
a_sh_stride
*
m_block_size
;
// number of shared write iterations for a tile
constexpr
int
a_sh_wr_iters
=
div_ceil
(
a_sh_stage
,
a_sh_wr_delta
);
// B sizes/strides
int
b_gl_stride
=
16
*
prob_n
/
(
pack_factor
*
4
);
constexpr
int
b_sh_stride
=
((
thread_n_blocks
*
16
)
*
16
/
pack_factor
)
/
4
;
constexpr
int
b_thread_vecs
=
w_type
.
size_bits
()
==
4
?
1
:
2
;
constexpr
int
b_sh_stride_threads
=
b_sh_stride
/
b_thread_vecs
;
int
b_gl_rd_delta_o
=
b_gl_stride
*
thread_k_blocks
;
int
b_gl_rd_delta_i
=
b_gl_stride
*
(
threads
/
b_sh_stride_threads
);
constexpr
int
b_sh_wr_delta
=
threads
*
b_thread_vecs
;
constexpr
int
b_sh_rd_delta
=
threads
*
b_thread_vecs
;
constexpr
int
b_sh_stage
=
b_sh_stride
*
thread_k_blocks
;
constexpr
int
b_sh_wr_iters
=
b_sh_stage
/
b_sh_wr_delta
;
// Scale sizes/strides without act_order
int
s_gl_stride
=
prob_n
/
8
;
constexpr
int
s_sh_stride
=
16
*
thread_n_blocks
/
8
;
constexpr
int
s_tb_groups
=
!
has_act_order
&&
group_blocks
!=
-
1
&&
group_blocks
<
thread_k_blocks
?
thread_k_blocks
/
group_blocks
:
1
;
constexpr
int
s_sh_stage
=
s_tb_groups
*
s_sh_stride
;
int
s_gl_rd_delta
=
s_gl_stride
;
// Scale size/strides with act_order
constexpr
int
tb_k
=
16
*
thread_k_blocks
;
constexpr
int
g_idx_stage
=
has_act_order
?
(
tb_k
*
sizeof
(
int
))
/
16
:
0
;
// constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups;
constexpr
int
act_s_max_num_groups
=
32
;
int
act_s_col_stride
=
1
;
int
act_s_col_warp_stride
=
act_s_col_stride
*
8
;
int
tb_n_warps
=
thread_n_blocks
/
4
;
int
act_s_col_tb_stride
=
act_s_col_warp_stride
*
tb_n_warps
;
// Zero-points sizes/strides
int
zp_gl_stride
=
is_zp_float
?
prob_n
/
8
:
(
prob_n
/
pack_factor
)
/
4
;
constexpr
int
zp_sh_stride
=
is_zp_float
?
16
*
thread_n_blocks
/
8
:
((
16
*
thread_n_blocks
)
/
pack_factor
)
/
4
;
constexpr
int
zp_tb_groups
=
s_tb_groups
;
constexpr
int
zp_sh_stage
=
has_zp
?
zp_tb_groups
*
zp_sh_stride
:
0
;
int
zp_gl_rd_delta
=
zp_gl_stride
;
// Global A read index of current thread.
int
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
a_gl_rd
+=
a_gl_rd_delta_o
*
slice_row
;
// Shared write index of current thread.
int
a_sh_wr
=
a_sh_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
// Shared read index.
int
a_sh_rd
=
a_sh_stride
*
((
threadIdx
.
x
%
32
)
%
(
16
/
(
m_block_size_8
?
2
:
1
)))
+
(
threadIdx
.
x
%
32
)
/
(
16
/
(
m_block_size_8
?
2
:
1
));
a_sh_rd
+=
2
*
((
threadIdx
.
x
/
32
)
/
(
thread_n_blocks
/
4
));
int
b_gl_rd
=
b_gl_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
(
threadIdx
.
x
%
b_sh_stride_threads
)
*
b_thread_vecs
;
b_gl_rd
+=
b_sh_stride
*
slice_col
;
b_gl_rd
+=
b_gl_rd_delta_o
*
slice_row
;
auto
b_sh_wr
=
threadIdx
.
x
*
b_thread_vecs
;
auto
b_sh_rd
=
threadIdx
.
x
*
b_thread_vecs
;
// For act_order
constexpr
int
k_iter_size
=
tb_k
/
b_sh_wr_iters
;
int
slice_k_start
=
tb_k
*
slice_row
;
int
slice_k_finish
=
slice_k_start
+
tb_k
*
slice_iters
;
int
slice_k_start_shared_fetch
=
slice_k_start
;
int
slice_n_offset
=
act_s_col_tb_stride
*
slice_col
;
// No act_order
int
s_gl_rd
;
if
constexpr
(
!
has_act_order
)
{
if
constexpr
(
group_blocks
==
-
1
)
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
else
{
s_gl_rd
=
s_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
+
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
auto
s_sh_wr
=
threadIdx
.
x
;
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
// Zero-points
int
zp_gl_rd
;
if
constexpr
(
has_zp
)
{
if
constexpr
(
group_blocks
==
-
1
)
{
zp_gl_rd
=
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
else
{
zp_gl_rd
=
zp_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
+
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
auto
zp_sh_wr
=
threadIdx
.
x
;
bool
zp_sh_wr_pred
=
threadIdx
.
x
<
zp_sh_stride
;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int
s_sh_rd
;
if
constexpr
(
group_blocks
!=
-
1
)
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
4
;
else
if
constexpr
(
group_blocks
==
-
1
&&
(
m_block_size_8
||
has_zp
))
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
8
;
else
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
%
4
;
// Zero-points have the same read layout as the scales
// (without column-wise case)
constexpr
int
num_col_threads
=
8
;
constexpr
int
num_row_threads
=
4
;
constexpr
int
num_ints_per_thread
=
8
/
pack_factor
;
int
zp_sh_rd
;
if
constexpr
(
has_zp
)
{
if
constexpr
(
is_zp_float
)
{
if
constexpr
(
group_blocks
!=
-
1
)
{
zp_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
4
;
}
}
else
{
zp_sh_rd
=
num_ints_per_thread
*
num_col_threads
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
num_ints_per_thread
*
((
threadIdx
.
x
%
32
)
/
num_row_threads
);
}
}
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
bool
a_sh_wr_pred
[
a_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
a_sh_wr_pred
[
i
]
=
a_sh_wr_delta
*
i
+
a_sh_wr
<
a_sh_stride
*
prob_m
;
// To ensure that writing and reading A tiles to/from shared memory, the
// latter in fragment format, is fully bank conflict free, we need to use a
// rather fancy XOR-based layout. The key here is that neither reads nor
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
// same shared memory banks. Further, it seems (based on NSight-Compute) that
// each warp must also write a consecutive memory segment?
auto
transform_a
=
[
&
](
int
i
)
{
int
row
=
i
/
a_gl_rd_delta_o
;
return
a_gl_rd_delta_o
*
row
+
(
i
%
a_gl_rd_delta_o
)
^
(
row
%
8
);
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes.
int
a_sh_wr_trans
[
a_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
a_sh_wr_trans
[
i
]
=
transform_a
(
a_sh_wr_delta
*
i
+
a_sh_wr
);
int
a_sh_rd_trans
[
b_sh_wr_iters
][
thread_m_blocks
];
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
thread_m_blocks
;
j
++
)
a_sh_rd_trans
[
i
][
j
]
=
transform_a
(
a_sh_rd_delta_o
*
i
+
a_sh_rd_delta_i
*
j
+
a_sh_rd
);
}
// Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependencies between subsequent accesses with a tile by
// maintining multiple pointers (we have enough registers), a tiny
// optimization.
const
int4
*
B_ptr
[
b_sh_wr_iters
];
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
=
B
+
b_gl_rd_delta_i
*
i
+
b_gl_rd
;
extern
__shared__
int4
sh
[];
// Shared memory storage for global fetch pipelines.
constexpr
int
sh_red_size
=
(
2
*
thread_n_blocks
+
1
)
*
16
*
thread_m_blocks
;
constexpr
int
sh_b_size
=
stages
*
b_sh_stage
;
int4
*
sh_b
=
sh
;
int4
*
sh_red
=
sh
;
int4
*
sh_g_idx
=
sh_b
+
(
sh_red_size
>
sh_b_size
?
sh_red_size
:
sh_b_size
);
int4
*
sh_zp
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
constexpr
int
sh_s_size
=
has_act_order
?
(
act_s_max_num_groups
*
s_sh_stride
)
:
(
stages
*
s_sh_stage
);
int4
*
sh_s
=
sh_zp
+
(
stages
*
zp_sh_stage
);
// shared memory reused by reduction should be smaller than
// shared memory used by weight.
static_assert
(
thread_m_blocks
*
16
*
thread_n_blocks
*
16
/
8
<=
stages
*
b_sh_stage
);
int4
*
sh_a
=
sh_s
+
sh_s_size
;
// constexpr int shm_size_used =
// stages * (g_idx_stage + zp_sh_stage) + sh_s_size +
// (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
// Register storage for double buffer of shared memory reads.
FragA
frag_a
[
2
][
thread_m_blocks
];
I4
frag_b_quant
[
2
][
b_thread_vecs
];
FragC
frag_c
[
thread_m_blocks
][
4
][
2
];
FragS
frag_s
[
2
][
4
];
// No act-order
FragS
act_frag_s
[
2
][
4
][
4
];
// For act-order
int
frag_qzp
[
2
][
num_ints_per_thread
];
// Zero-points
FragZP
frag_zp
;
// Zero-points in fp16
FragZP
frag_zpf
[
2
];
// Zero-points in fp16 in HQQ
// Zero accumulators.
auto
zero_accums
=
[
&
]()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
*
4
*
2
*
4
;
i
++
)
reinterpret_cast
<
float
*>
(
frag_c
)[
i
]
=
0
;
};
int
sh_first_group_id
=
-
1
;
int
sh_num_groups
=
-
1
;
auto
fetch_act_order_scales_to_shared
=
[
&
](
bool
is_async
,
int
first_group_id
,
int
last_group_id
)
{
sh_first_group_id
=
first_group_id
;
sh_num_groups
=
last_group_id
-
first_group_id
+
1
;
if
(
sh_num_groups
<
act_s_max_num_groups
)
{
sh_num_groups
=
act_s_max_num_groups
;
}
if
(
sh_first_group_id
+
sh_num_groups
>
num_groups
)
{
sh_num_groups
=
num_groups
-
sh_first_group_id
;
}
int
row_offset
=
first_group_id
*
s_gl_stride
;
if
(
is_async
)
{
for
(
int
i
=
0
;
i
<
sh_num_groups
;
i
++
)
{
if
(
threadIdx
.
x
<
s_sh_stride
)
{
cp_async4_pred
(
&
sh_s
[(
i
*
s_sh_stride
)
+
threadIdx
.
x
],
&
scales_ptr
[
row_offset
+
(
i
*
s_gl_stride
)
+
slice_n_offset
+
threadIdx
.
x
]);
}
}
}
else
{
for
(
int
i
=
0
;
i
<
sh_num_groups
;
i
++
)
{
if
(
threadIdx
.
x
<
s_sh_stride
)
{
sh_s
[(
i
*
s_sh_stride
)
+
threadIdx
.
x
]
=
scales_ptr
[
row_offset
+
(
i
*
s_gl_stride
)
+
slice_n_offset
+
threadIdx
.
x
];
}
}
}
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
auto
fetch_to_shared
=
[
&
](
int
pipe
,
int
a_off
,
bool
pred
=
true
)
{
if
(
pred
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
{
cp_async4_pred
(
&
sh_a_stage
[
a_sh_wr_trans
[
i
]],
&
A
[
a_gl_rd_delta_i
*
i
+
a_gl_rd
+
a_gl_rd_delta_o
*
a_off
],
a_sh_wr_pred
[
i
]);
}
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
b_thread_vecs
;
j
++
)
{
cp_async4
(
&
sh_b_stage
[
b_sh_wr_delta
*
i
+
b_sh_wr
+
j
],
B_ptr
[
i
]
+
j
);
}
B_ptr
[
i
]
+=
b_gl_rd_delta_o
;
}
if
constexpr
(
has_act_order
)
{
// Fetch g_idx thread-block portion
int
full_pipe
=
a_off
;
int
cur_k
=
slice_k_start_shared_fetch
+
tb_k
*
full_pipe
;
if
(
cur_k
<
prob_k
&&
cur_k
<
slice_k_finish
)
{
int4
*
sh_g_idx_stage
=
sh_g_idx
+
g_idx_stage
*
pipe
;
int4
const
*
cur_g_idx_stage_ptr
=
reinterpret_cast
<
int4
const
*>
(
&
g_idx
[
cur_k
]);
if
(
threadIdx
.
x
<
g_idx_stage
)
{
cp_async4_pred
(
&
sh_g_idx_stage
[
threadIdx
.
x
],
&
cur_g_idx_stage_ptr
[
threadIdx
.
x
]);
}
}
}
else
{
if
constexpr
(
group_blocks
!=
-
1
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
// Only fetch scales if this tile starts a new group
if
(
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s_stage
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
s_gl_rd
+=
s_gl_rd_delta
;
}
}
else
{
for
(
int
i
=
0
;
i
<
s_tb_groups
;
i
++
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s_stage
[
i
*
s_sh_stride
+
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
s_gl_rd
+=
s_gl_rd_delta
;
}
}
}
if
constexpr
(
has_zp
&&
group_blocks
!=
-
1
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pipe
;
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
// Only fetch zero-points if this tile starts a new group
if
(
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
if
(
zp_sh_wr_pred
)
{
cp_async4
(
&
sh_zp_stage
[
zp_sh_wr
],
&
zp_ptr
[
zp_gl_rd
]);
}
zp_gl_rd
+=
zp_gl_rd_delta
;
}
}
else
{
for
(
int
i
=
0
;
i
<
zp_tb_groups
;
i
++
)
{
if
(
zp_sh_wr_pred
)
{
cp_async4
(
&
sh_zp_stage
[
i
*
zp_sh_stride
+
zp_sh_wr
],
&
zp_ptr
[
zp_gl_rd
]);
}
zp_gl_rd
+=
zp_gl_rd_delta
;
}
}
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// waiting is also correct at this point.
cp_async_fence
();
};
auto
fetch_col_zp_to_shared
=
[
&
]()
{
if
(
zp_sh_wr_pred
)
{
cp_async4
(
&
sh_zp
[
zp_sh_wr
],
&
zp_ptr
[
zp_gl_rd
]);
}
};
auto
fetch_col_scale_to_shared
=
[
&
]()
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
};
// Wait until the next thread tile has been loaded to shared memory.
auto
wait_for_stage
=
[
&
]()
{
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait
<
stages
-
2
>
();
__syncthreads
();
};
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto
fetch_to_registers
=
[
&
](
int
k
,
int
pipe
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
ldsm
<
m_block_size_8
?
2
:
4
,
scalar_t
>
(
frag_a
[
k
%
2
][
i
],
&
sh_a_stage
[
a_sh_rd_trans
[
k
%
b_sh_wr_iters
][
i
]]);
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_thread_vecs
;
i
++
)
{
frag_b_quant
[
k
%
2
][
i
]
=
*
reinterpret_cast
<
I4
*>
(
&
sh_b_stage
[
b_sh_rd_delta
*
(
k
%
b_sh_wr_iters
)
+
b_sh_rd
+
i
]);
}
};
bool
is_same_group
[
stages
];
int
same_group_id
[
stages
];
auto
init_same_group
=
[
&
](
int
pipe
)
{
if
constexpr
(
!
has_act_order
)
{
return
;
}
int4
*
sh_g_idx_stage
=
sh_g_idx
+
g_idx_stage
*
pipe
;
int
*
sh_g_idx_int_ptr
=
reinterpret_cast
<
int
*>
(
sh_g_idx_stage
);
int
group_id_1
=
sh_g_idx_int_ptr
[
0
];
int
group_id_2
=
sh_g_idx_int_ptr
[
tb_k
-
1
];
is_same_group
[
pipe
]
=
group_id_1
==
group_id_2
;
same_group_id
[
pipe
]
=
group_id_1
;
};
auto
fetch_scales_to_registers
=
[
&
](
int
k
,
int
full_pipe
)
{
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
!
has_act_order
)
{
// No act-order case
if
constexpr
(
group_blocks
==
-
1
)
{
// load only when starting a new slice
if
(
k
==
0
&&
full_pipe
==
0
)
{
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
0
]
=
sh_s
[
s_sh_rd
];
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
1
]
=
sh_s
[
s_sh_rd
+
4
];
}
}
else
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
if
(
k
%
b_sh_wr_iters
==
0
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
];
}
else
{
reinterpret_cast
<
int4
*>
(
&
frag_s
[
1
])[
0
]
=
reinterpret_cast
<
int4
*>
(
&
frag_s
[
0
])[
0
];
}
}
else
{
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
int
cur_k
=
warp_row
*
16
;
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
int
cur_group_id
=
k_blocks
/
group_blocks
;
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
+
cur_group_id
*
s_sh_stride
];
}
}
return
;
}
// Act-order case
// Determine K of the "current" thread-block
int
cur_k
=
slice_k_start
+
tb_k
*
full_pipe
;
if
(
cur_k
>=
prob_k
||
cur_k
>=
slice_k_finish
)
{
return
;
}
// Reset (to current thread-block) since we read g_idx portion from the
// shared memory
cur_k
=
0
;
// Progress to current iteration
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
// Determine "position" inside the thread-block (based on warp and
// thread-id)
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
// Each warp processes 4 16-size tiles over N
int
warp_row
=
warp_id
/
n_warps
;
int
warp_col
=
warp_id
%
n_warps
;
cur_k
+=
warp_row
*
16
;
auto
th_id
=
threadIdx
.
x
%
32
;
cur_k
+=
(
th_id
%
4
)
*
2
;
// Due to tensor-core layout for fp16 B matrix
int
s_col_shift
=
/*slice_n_offset +*/
(
act_s_col_warp_stride
*
warp_col
)
+
(
th_id
/
4
)
*
act_s_col_stride
;
if
(
is_same_group
[
pipe
])
{
if
(
k
%
2
==
0
)
{
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
0
][
0
])))
=
sh_s
[(
same_group_id
[
pipe
]
-
sh_first_group_id
)
*
s_sh_stride
+
s_col_shift
];
}
else
{
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
0
][
0
])))
=
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[(
k
-
1
)
%
2
][
0
][
0
])));
}
for
(
int
i
=
1
;
i
<
4
;
i
++
)
{
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
i
][
0
])))
=
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
0
][
0
])));
}
return
;
}
int4
*
sh_g_idx_stage
=
sh_g_idx
+
g_idx_stage
*
pipe
;
int
*
sh_g_idx_int_ptr
=
reinterpret_cast
<
int
*>
(
sh_g_idx_stage
);
constexpr
int
k_frag_offsets
[
4
]
=
{
0
,
1
,
8
,
9
};
// Tensor core offsets per thread
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
int
actual_k
=
cur_k
+
k_frag_offsets
[
i
];
int
group_id
=
sh_g_idx_int_ptr
[
actual_k
];
int
rel_group_id
=
group_id
-
sh_first_group_id
;
*
(
reinterpret_cast
<
int4
*>
(
&
(
act_frag_s
[
k
%
2
][
i
][
0
])))
=
sh_s
[
rel_group_id
*
s_sh_stride
+
s_col_shift
];
}
};
auto
fetch_zp_to_registers
=
[
&
](
int
k
,
int
full_pipe
)
{
// This code does not handle group_blocks == 0,
// which signifies act_order.
// has_zp implies AWQ, which doesn't have act_order,
static_assert
(
!
has_zp
||
group_blocks
!=
0
);
if
constexpr
(
has_zp
&&
!
is_zp_float
)
{
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
group_blocks
==
-
1
)
{
// load only when starting a new slice
if
(
k
==
0
&&
full_pipe
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp
))[
zp_sh_rd
+
i
];
}
}
}
else
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
if
(
k
%
b_sh_wr_iters
==
0
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
#pragma unroll
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
}
else
{
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
int
cur_k
=
warp_row
*
16
;
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
int
cur_group_id
=
0
;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
cur_group_id
=
k_blocks
/
group_blocks
;
#pragma nv_diagnostic pop
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pipe
;
sh_zp_stage
+=
cur_group_id
*
zp_sh_stride
;
#pragma unroll
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
}
else
if
constexpr
(
has_zp
&&
is_zp_float
)
{
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
if
(
k
%
b_sh_wr_iters
==
0
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_zpf
[
k
%
2
])[
0
]
=
sh_zp_stage
[
zp_sh_rd
];
}
}
else
{
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
int
cur_k
=
warp_row
*
16
;
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
int
cur_group_id
=
k_blocks
/
group_blocks
;
#pragma nv_diagnostic pop
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pipe
;
reinterpret_cast
<
int4
*>
(
&
frag_zpf
[
k
%
2
])[
0
]
=
sh_zp_stage
[
zp_sh_rd
+
cur_group_id
*
zp_sh_stride
];
}
}
}
};
auto
dequant_data
=
[
&
](
int
q
,
scalar_t2
*
frag_b_ptr
)
{
if
constexpr
(
has_zp
&&
is_zp_float
||
!
has_zp
)
{
dequant
<
scalar_t2
,
w_type_id
>
(
q
,
frag_b_ptr
);
}
else
{
static_assert
(
has_zp
&&
!
is_zp_float
);
static_assert
(
w_type_id
==
vllm
::
kU4
.
id
()
||
w_type_id
==
vllm
::
kU8
.
id
());
// If (has_zp && !is_zp_float),
// we use not-zp version `dequant` function
// to improve numerical accuracy.
// Since both weight and zero point are dequanted using this logic,
// the final dequanted weight would be correct.
if
constexpr
(
w_type_id
==
vllm
::
kU4
.
id
())
{
dequant
<
scalar_t2
,
vllm
::
kU4B8
.
id
()
>
(
q
,
frag_b_ptr
);
}
else
if
constexpr
(
w_type_id
==
vllm
::
kU8
.
id
())
{
dequant
<
scalar_t2
,
vllm
::
kU8B128
.
id
()
>
(
q
,
frag_b_ptr
);
}
}
};
// Execute the actual tensor core matmul of a sub-tile.
bool
is_first_matmul_in_slice
=
true
;
auto
matmul
=
[
&
](
int
k
)
{
int
k2
=
k
%
2
;
const
bool
is_new_zp
=
((
group_blocks
!=
-
1
)
&&
(
group_blocks
<
thread_k_blocks
||
k
==
0
))
||
(
group_blocks
==
-
1
&&
is_first_matmul_in_slice
);
if
constexpr
(
has_zp
&&
!
is_zp_float
)
{
if
(
is_new_zp
)
{
if
constexpr
(
group_blocks
==
-
1
)
is_first_matmul_in_slice
=
false
;
FragB
frag_zp_0
;
FragB
frag_zp_1
;
int
zp_quant_0
,
zp_quant_1
;
if
constexpr
(
w_type
.
size_bits
()
==
4
)
{
zp_quant_0
=
frag_qzp
[
k2
][
0
];
zp_quant_1
=
zp_quant_0
>>
8
;
}
else
{
static_assert
(
w_type
.
size_bits
()
==
8
);
zp_quant_0
=
frag_qzp
[
k2
][
0
];
zp_quant_1
=
frag_qzp
[
k2
][
1
];
}
dequant_data
(
zp_quant_0
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_zp
));
dequant_data
(
zp_quant_1
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_zp
)
+
2
);
}
}
if
constexpr
(
has_zp
&&
is_zp_float
)
{
if
(
is_new_zp
)
{
reinterpret_cast
<
int4
*>
(
&
frag_zp
)[
0
]
=
reinterpret_cast
<
int4
*>
(
&
frag_zpf
[
k2
])[
0
];
}
}
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
FragB
frag_b0
;
FragB
frag_b1
;
int
b_quant_0
,
b_quant_1
;
if
constexpr
(
w_type
.
size_bits
()
==
4
)
{
b_quant_0
=
frag_b_quant
[
k2
][
0
][
j
];
b_quant_1
=
b_quant_0
>>
8
;
}
else
{
static_assert
(
w_type
.
size_bits
()
==
8
);
int
*
frag_b_quant_ptr
=
reinterpret_cast
<
int
*>
(
frag_b_quant
[
k2
]);
b_quant_0
=
frag_b_quant_ptr
[
j
*
2
+
0
];
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
}
dequant_data
(
b_quant_0
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_b0
));
dequant_data
(
b_quant_1
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_b1
));
// Apply scale to frag_b0
if
constexpr
(
has_act_order
)
{
static_assert
(
group_blocks
!=
-
1
);
scale4
<
scalar_t
>
(
frag_b0
,
act_frag_s
[
k2
][
0
][
j
],
act_frag_s
[
k2
][
1
][
j
],
act_frag_s
[
k2
][
2
][
j
],
act_frag_s
[
k2
][
3
][
j
],
0
);
scale4
<
scalar_t
>
(
frag_b1
,
act_frag_s
[
k2
][
0
][
j
],
act_frag_s
[
k2
][
1
][
j
],
act_frag_s
[
k2
][
2
][
j
],
act_frag_s
[
k2
][
3
][
j
],
1
);
}
else
if
constexpr
(
has_zp
&&
!
is_zp_float
&&
group_blocks
==
-
1
)
{
int
idx
=
(
threadIdx
.
x
/
4
)
%
2
;
scalar_t2
s2
=
Dtype
::
nums2num2
(
reinterpret_cast
<
scalar_t
*>
(
&
frag_s
[
j
/
2
][
j
%
2
*
2
+
0
])[
idx
],
reinterpret_cast
<
scalar_t
*>
(
&
frag_s
[
j
/
2
][
j
%
2
*
2
+
1
])[
idx
]);
if
(
is_new_zp
)
frag_zp
[
j
]
=
__hmul2
(
frag_zp
[
j
],
s2
);
scale_and_sub
<
scalar_t
>
(
frag_b0
,
s2
.
x
,
frag_zp
[
j
].
x
);
scale_and_sub
<
scalar_t
>
(
frag_b1
,
s2
.
y
,
frag_zp
[
j
].
y
);
}
else
if
constexpr
(
has_zp
&&
group_blocks
!=
-
1
)
{
if
(
is_new_zp
)
frag_zp
[
j
]
=
__hmul2
(
frag_zp
[
j
],
*
reinterpret_cast
<
scalar_t2
*>
(
&
frag_s
[
k2
][
j
]));
scale_and_sub
<
scalar_t
>
(
frag_b0
,
frag_s
[
k2
][
j
][
0
].
x
,
frag_zp
[
j
].
x
);
scale_and_sub
<
scalar_t
>
(
frag_b1
,
frag_s
[
k2
][
j
][
0
].
y
,
frag_zp
[
j
].
y
);
}
else
if
constexpr
(
group_blocks
!=
-
1
)
{
scale
<
scalar_t
>
(
frag_b0
,
frag_s
[
k2
][
j
],
0
);
scale
<
scalar_t
>
(
frag_b1
,
frag_s
[
k2
][
j
],
1
);
}
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
if
constexpr
(
m_block_size_8
)
{
mma_trans
<
scalar_t
>
(
frag_a
[
k2
][
i
],
frag_b0
,
frag_b1
,
frag_c
[
i
][
j
][
0
]);
}
else
{
mma
<
scalar_t
>
(
frag_a
[
k2
][
i
],
frag_b0
,
frag_c
[
i
][
j
][
0
]);
mma
<
scalar_t
>
(
frag_a
[
k2
][
i
],
frag_b1
,
frag_c
[
i
][
j
][
1
]);
}
}
}
};
// Since we slice across the k dimension of a tile in order to increase the
// number of warps while keeping the n dimension of a tile reasonable, we have
// multiple warps that accumulate their partial sums of the same output
// location; which we have to reduce over in the end. We do in shared memory.
auto
thread_block_reduce
=
[
&
]()
{
constexpr
int
red_off
=
threads
/
b_sh_stride_threads
/
2
;
if
(
red_off
>=
1
)
{
auto
red_idx
=
threadIdx
.
x
/
b_sh_stride_threads
;
constexpr
int
red_sh_stride
=
b_sh_stride_threads
*
4
*
2
;
constexpr
int
red_sh_delta
=
b_sh_stride_threads
;
int
red_sh_rd
=
red_sh_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
(
threadIdx
.
x
%
b_sh_stride_threads
);
// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only
// once by warp 1 and read only once by warp 0.
#pragma unroll
for
(
int
m_block
=
0
;
m_block
<
thread_m_blocks
;
m_block
++
)
{
#pragma unroll
for
(
int
i
=
red_off
;
i
>
0
;
i
/=
2
)
{
if
(
i
<=
red_idx
&&
red_idx
<
2
*
i
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
*
2
;
j
+=
(
m_block_size_8
?
2
:
1
))
{
int
red_sh_wr
=
red_sh_delta
*
j
+
(
red_sh_rd
-
red_sh_stride
*
i
);
if
(
i
<
red_off
)
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh_red
[
red_sh_delta
*
j
+
red_sh_rd
]);
float
*
c_wr
=
reinterpret_cast
<
float
*>
(
&
sh_red
[
red_sh_wr
]);
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
j
][
k
]
+=
c_rd
[
k
]
+
c_wr
[
k
];
}
sh_red
[
red_sh_wr
]
=
reinterpret_cast
<
int4
*>
(
&
frag_c
)[
4
*
2
*
m_block
+
j
];
}
}
__syncthreads
();
}
if
(
red_idx
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
4
*
2
;
i
+=
(
m_block_size_8
?
2
:
1
))
{
float
*
c_rd
=
reinterpret_cast
<
float
*>
(
&
sh_red
[
red_sh_delta
*
i
+
red_sh_rd
]);
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
reinterpret_cast
<
FragC
*>
(
frag_c
)[
4
*
2
*
m_block
+
i
][
j
]
+=
c_rd
[
j
];
}
}
__syncthreads
();
}
}
};
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
auto
global_reduce_fp16
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
constexpr
int
active_threads
=
32
*
thread_n_blocks
/
4
;
if
(
threadIdx
.
x
<
active_threads
)
{
int
c_gl_stride
=
prob_n
/
8
;
int
c_gl_wr_delta_o
=
8
*
c_gl_stride
;
int
c_gl_wr_delta_i
=
4
*
(
active_threads
/
32
);
int
c_gl_wr
;
if
constexpr
(
m_block_size_8
)
{
c_gl_wr
=
c_gl_stride
*
((
threadIdx
.
x
%
4
)
*
2
)
+
4
*
(
threadIdx
.
x
/
32
)
+
(
threadIdx
.
x
%
32
)
/
8
;
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
}
else
{
c_gl_wr
=
c_gl_stride
*
((
threadIdx
.
x
%
32
)
/
4
)
+
4
*
(
threadIdx
.
x
/
32
)
+
threadIdx
.
x
%
4
;
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
}
constexpr
int
c_sh_wr_delta
=
active_threads
;
auto
c_sh_wr
=
threadIdx
.
x
;
int
row
=
(
threadIdx
.
x
%
32
)
/
4
;
if
(
!
first
)
{
// Interestingly, doing direct global accesses here really seems to mess up
// the compiler and lead to slowdowns, hence we also use async-copies even
// though these fetches are not actually asynchronous.
#pragma unroll
for
(
int
i
=
0
;
i
<
(
m_block_size_8
?
2
:
thread_m_blocks
*
4
);
i
++
)
{
if
constexpr
(
m_block_size_8
)
{
cp_async4_pred
(
&
sh_red
[
c_sh_wr
+
c_sh_wr_delta
*
i
],
&
C
[
c_gl_wr
+
i
*
c_gl_stride
+
(
threadIdx
.
x
%
8
)
/
4
*
c_gl_wr_delta_i
],
(
threadIdx
.
x
%
4
)
*
2
+
i
<
prob_m
);
}
else
{
cp_async4_pred
(
&
sh_red
[
c_sh_wr
+
c_sh_wr_delta
*
i
],
&
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)],
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
);
}
}
cp_async_fence
();
cp_async_wait
<
0
>
();
}
#pragma unroll
for
(
int
i
=
0
;
i
<
(
m_block_size_8
?
2
:
thread_m_blocks
*
4
);
i
++
)
{
bool
mask
=
(
!
m_block_size_8
)
&&
(
i
<
(
thread_m_blocks
-
1
)
*
4
||
8
*
(
i
/
2
)
+
row
<
prob_m
)
||
(
m_block_size_8
)
&&
((
threadIdx
.
x
%
4
)
*
2
+
i
<
prob_m
);
if
(
mask
)
{
if
(
!
first
)
{
int4
c_red
=
sh_red
[
c_sh_wr
+
i
*
c_sh_wr_delta
];
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
int
delta
=
0
;
if
constexpr
(
m_block_size_8
)
{
delta
=
j
%
2
==
1
?
-
2
:
0
;
}
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)
+
delta
]
+=
Dtype
::
num2float
(
reinterpret_cast
<
scalar_t
*>
(
&
c_red
)[
j
]);
}
}
if
(
!
last
)
{
int4
c
;
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
int
delta
=
0
;
if
constexpr
(
m_block_size_8
)
{
delta
=
j
%
2
==
1
?
-
2
:
0
;
}
reinterpret_cast
<
scalar_t
*>
(
&
c
)[
j
]
=
Dtype
::
float2num
(
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)
+
delta
]);
}
if
constexpr
(
m_block_size_8
)
C
[
c_gl_wr
+
i
*
c_gl_stride
+
(
threadIdx
.
x
%
8
)
/
4
*
c_gl_wr_delta_i
]
=
c
;
else
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)]
=
c
;
}
}
}
}
};
// Globally reduce over threadblocks that compute the same column block.
// We use a tmp C buffer to reduce in full fp32 precision.
auto
global_reduce_fp32
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
constexpr
int
tb_m
=
thread_m_blocks
*
16
;
constexpr
int
tb_n
=
thread_n_blocks
*
16
;
constexpr
int
c_size
=
tb_m
*
tb_n
*
sizeof
(
float
)
/
16
;
constexpr
int
active_threads
=
32
*
thread_n_blocks
/
4
;
bool
is_th_active
=
threadIdx
.
x
<
active_threads
;
constexpr
int
num_floats
=
thread_m_blocks
*
4
*
2
*
4
;
constexpr
int
th_size
=
num_floats
*
sizeof
(
float
)
/
16
;
int
c_cur_offset
=
locks_off
*
c_size
;
if
(
!
is_th_active
)
{
return
;
}
if
(
!
first
)
{
float
*
frag_c_ptr
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
#pragma unroll
for
(
int
k
=
0
;
k
<
th_size
;
k
+=
(
m_block_size_8
?
2
:
1
))
{
sh_red
[
threadIdx
.
x
]
=
C_tmp
[
c_cur_offset
+
active_threads
*
k
+
threadIdx
.
x
];
float
*
sh_c_ptr
=
reinterpret_cast
<
float
*>
(
&
sh_red
[
threadIdx
.
x
]);
#pragma unroll
for
(
int
f
=
0
;
f
<
4
;
f
++
)
{
frag_c_ptr
[
k
*
4
+
f
]
+=
sh_c_ptr
[
f
];
}
}
}
if
(
!
last
)
{
int4
*
frag_c_ptr
=
reinterpret_cast
<
int4
*>
(
&
frag_c
);
#pragma unroll
for
(
int
k
=
0
;
k
<
th_size
;
k
+=
(
m_block_size_8
?
2
:
1
))
{
C_tmp
[
c_cur_offset
+
active_threads
*
k
+
threadIdx
.
x
]
=
frag_c_ptr
[
k
];
}
}
};
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
auto
write_result
=
[
&
]()
{
int
c_gl_stride
=
prob_n
/
8
;
constexpr
int
c_sh_stride
=
2
*
thread_n_blocks
+
1
;
int
c_gl_wr_delta
=
c_gl_stride
*
(
threads
/
(
2
*
thread_n_blocks
));
constexpr
int
c_sh_rd_delta
=
c_sh_stride
*
(
threads
/
(
2
*
thread_n_blocks
));
int
c_gl_wr
=
c_gl_stride
*
(
threadIdx
.
x
/
(
2
*
thread_n_blocks
))
+
(
threadIdx
.
x
%
(
2
*
thread_n_blocks
));
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
int
c_sh_wr
;
if
constexpr
(
m_block_size_8
)
{
c_sh_wr
=
(
8
*
c_sh_stride
)
*
((
threadIdx
.
x
%
32
)
%
4
*
2
)
+
(
threadIdx
.
x
%
32
)
/
4
;
c_sh_wr
+=
64
*
(
threadIdx
.
x
/
32
);
}
else
{
c_sh_wr
=
(
4
*
c_sh_stride
)
*
((
threadIdx
.
x
%
32
)
/
4
)
+
(
threadIdx
.
x
%
32
)
%
4
;
c_sh_wr
+=
32
*
(
threadIdx
.
x
/
32
);
}
int
c_sh_rd
=
c_sh_stride
*
(
threadIdx
.
x
/
(
2
*
thread_n_blocks
))
+
(
threadIdx
.
x
%
(
2
*
thread_n_blocks
));
int
c_gl_wr_end
=
c_gl_stride
*
prob_m
;
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto
write
=
[
&
](
int
idx
,
float
c0
,
float
c1
,
FragS
&
s
)
{
scalar_t2
res
=
Dtype
::
nums2num2
(
Dtype
::
float2num
(
c0
),
Dtype
::
float2num
(
c1
));
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
w_type
.
size_bits
()
==
4
&&
!
has_zp
)
{
res
=
__hmul2
(
res
,
s
[
0
]);
}
if
constexpr
(
m_block_size_8
)
{
((
scalar_t
*
)
sh_red
)[
idx
]
=
res
.
x
;
((
scalar_t
*
)
sh_red
)[
idx
+
8
*
c_sh_stride
]
=
res
.
y
;
}
else
{
((
scalar_t2
*
)
sh_red
)[
idx
]
=
res
;
}
};
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
if
constexpr
(
m_block_size_8
)
{
int
wr
=
c_sh_wr
+
16
*
j
;
write
(
wr
,
frag_c
[
i
][
j
][
0
][
0
],
frag_c
[
i
][
j
][
0
][
1
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
write
(
wr
+
8
,
frag_c
[
i
][
j
][
0
][
2
],
frag_c
[
i
][
j
][
0
][
3
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
}
else
{
int
wr
=
c_sh_wr
+
8
*
j
;
write
(
wr
+
(
4
*
c_sh_stride
)
*
0
+
0
,
frag_c
[
i
][
j
][
0
][
0
],
frag_c
[
i
][
j
][
0
][
1
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
8
+
0
,
frag_c
[
i
][
j
][
0
][
2
],
frag_c
[
i
][
j
][
0
][
3
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
0
+
4
,
frag_c
[
i
][
j
][
1
][
0
],
frag_c
[
i
][
j
][
1
][
1
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
write
(
wr
+
(
4
*
c_sh_stride
)
*
8
+
4
,
frag_c
[
i
][
j
][
1
][
2
],
frag_c
[
i
][
j
][
1
][
3
],
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
}
}
c_sh_wr
+=
16
*
(
4
*
c_sh_stride
);
}
}
__syncthreads
();
#pragma unroll
for
(
int
i
=
0
;
i
<
div_ceil
(
16
*
thread_m_blocks
,
threads
/
(
2
*
thread_n_blocks
));
i
++
)
{
if
(
c_gl_wr
<
c_gl_wr_end
)
{
if
(
use_atomic_add
&&
slice_count
>
1
)
{
scalar_t2
*
C_half2
=
reinterpret_cast
<
scalar_t2
*>
(
&
C
[
c_gl_wr
]);
scalar_t2
*
sh_red_half2
=
reinterpret_cast
<
scalar_t2
*>
(
&
sh_red
[
c_sh_rd
]);
#pragma unroll
for
(
int
a
=
0
;
a
<
4
;
a
++
)
{
atomicAdd
(
&
C_half2
[
a
],
sh_red_half2
[
a
]);
}
}
else
{
C
[
c_gl_wr
]
=
sh_red
[
c_sh_rd
];
}
c_gl_wr
+=
c_gl_wr_delta
;
c_sh_rd
+=
c_sh_rd_delta
;
}
}
__syncthreads
();
};
// Start global fetch and register load pipelines.
auto
start_pipes
=
[
&
]()
{
#pragma unroll
for
(
int
i
=
0
;
i
<
stages
-
1
;
i
++
)
{
if
(
has_act_order
&&
i
==
0
)
{
int
last_g_idx
=
slice_k_start
+
stages
*
tb_k
*
2
;
if
(
last_g_idx
>=
prob_k
)
{
last_g_idx
=
prob_k
-
1
;
}
fetch_act_order_scales_to_shared
(
true
,
g_idx
[
slice_k_start
],
g_idx
[
last_g_idx
]);
}
if
constexpr
(
has_zp
&&
!
is_zp_float
&&
group_blocks
==
-
1
)
{
if
(
i
==
0
)
{
fetch_col_zp_to_shared
();
fetch_col_scale_to_shared
();
}
}
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
);
}
zero_accums
();
wait_for_stage
();
init_same_group
(
0
);
fetch_to_registers
(
0
,
0
);
fetch_scales_to_registers
(
0
,
0
);
fetch_zp_to_registers
(
0
,
0
);
a_gl_rd
+=
a_gl_rd_delta_o
*
(
stages
-
1
);
if
constexpr
(
has_act_order
)
{
slice_k_start_shared_fetch
+=
tb_k
*
(
stages
-
1
);
}
};
if
(
slice_iters
)
{
start_pipes
();
}
// Main loop.
while
(
slice_iters
)
{
// We unroll over both the global fetch and the register load pipeline to
// ensure all shared memory accesses are static. Note that both pipelines
// have even length meaning that the next iteration will always start at
// index 0.
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
stages
;)
{
#pragma unroll
for
(
int
k
=
0
;
k
<
b_sh_wr_iters
;
k
++
)
{
fetch_to_registers
(
k
+
1
,
pipe
%
stages
);
fetch_scales_to_registers
(
k
+
1
,
pipe
);
fetch_zp_to_registers
(
k
+
1
,
pipe
);
if
(
k
==
b_sh_wr_iters
-
2
)
{
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
slice_iters
>=
stages
);
pipe
++
;
wait_for_stage
();
init_same_group
(
pipe
%
stages
);
}
matmul
(
k
);
}
slice_iters
--
;
if
(
slice_iters
==
0
)
{
break
;
}
}
a_gl_rd
+=
a_gl_rd_delta_o
*
stages
;
if
constexpr
(
has_act_order
)
{
slice_k_start
+=
tb_k
*
stages
;
slice_k_start_shared_fetch
+=
tb_k
*
stages
;
int
first_group_id
=
g_idx
[
slice_k_start
];
int
last_g_idx
=
slice_k_start
+
stages
*
tb_k
*
2
;
if
(
last_g_idx
>=
prob_k
)
{
last_g_idx
=
prob_k
-
1
;
}
int
last_group_id
=
g_idx
[
last_g_idx
];
if
(
last_group_id
>=
sh_first_group_id
+
sh_num_groups
)
{
fetch_act_order_scales_to_shared
(
false
,
first_group_id
,
last_group_id
);
__syncthreads
();
}
}
// Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation.
if
(
slice_iters
==
0
)
{
cp_async_wait
<
0
>
();
bool
last
=
slice_idx
==
slice_count
-
1
;
// For per-column scales, we only fetch them here in the final step before
// write-out
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
!
has_zp
)
{
if
(
w_type
.
size_bits
()
==
8
||
(
last
||
use_atomic_add
))
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
cp_async_fence
();
}
}
thread_block_reduce
();
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
!
has_zp
)
{
if
(
w_type
.
size_bits
()
==
8
||
(
last
||
use_atomic_add
))
{
cp_async_wait
<
0
>
();
__syncthreads
();
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
0
]
=
sh_s
[
s_sh_rd
+
0
];
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
1
]
=
sh_s
[
s_sh_rd
+
4
];
if
constexpr
(
m_block_size_8
)
{
int
idx
=
(
threadIdx
.
x
/
4
)
%
2
;
scalar_t2
*
frag_s_half2
=
reinterpret_cast
<
scalar_t2
*>
(
frag_s
);
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
frag_s_half2
[
i
]
=
Dtype
::
num2num2
(
reinterpret_cast
<
scalar_t
*>
(
&
frag_s_half2
[
i
])[
idx
]);
}
}
}
}
}
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
w_type
.
size_bits
()
==
8
&&
!
has_zp
)
{
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
0
][
0
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
0
][
2
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
(
m_block_size_8
?
1
:
0
)]);
if
constexpr
(
!
m_block_size_8
)
{
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
1
][
0
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
1
][
2
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
}
}
}
}
}
if
(
slice_count
>
1
&&
!
use_atomic_add
)
{
// only globally reduce if there is more than one block in a slice
barrier_acquire
(
&
locks
[
locks_off
],
slice_idx
);
if
(
use_fp32_reduce
)
{
global_reduce_fp32
(
slice_idx
==
0
,
last
);
}
else
{
global_reduce_fp16
(
slice_idx
==
0
,
last
);
}
barrier_release
(
&
locks
[
locks_off
],
last
);
}
if
(
use_atomic_add
&&
slice_count
>
1
&&
slice_idx
!=
0
)
wait_negative_and_add
(
&
locks
[
locks_off
]);
if
(
last
||
use_atomic_add
)
// only the last block in a slice actually writes the result
write_result
();
slice_row
=
0
;
slice_col_par
++
;
slice_col
++
;
is_first_matmul_in_slice
=
true
;
init_slice
();
if
(
slice_iters
)
{
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
+=
b_sh_stride
-
b_gl_rd_delta_o
*
k_tiles
;
if
(
slice_col
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
-=
b_gl_stride
;
}
// Update slice k/n for scales loading
if
constexpr
(
has_act_order
)
{
slice_k_start
=
tb_k
*
slice_row
;
slice_k_finish
=
slice_k_start
+
tb_k
*
slice_iters
;
slice_k_start_shared_fetch
=
slice_k_start
;
slice_n_offset
=
act_s_col_tb_stride
*
slice_col
;
}
else
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
zp_gl_rd
=
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
start_pipes
();
}
}
}
}
}
// namespace MARLIN_NAMESPACE_NAME
#endif
csrc/torch_bindings.cpp
View file @
1d0c9d6b
...
...
@@ -291,12 +291,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops
.
def
(
"gptq_marlin_gemm(Tensor a, Tensor
b_q_weight
, Tensor b_
scales
, "
"Tensor b_
zero
s, Tensor
g_idx, Tensor perm, Tensor workspac
e, "
"int b_q_type, "
"gptq_marlin_gemm(Tensor a, Tensor
? c_or_none
, Tensor b_
q_weight
, "
"Tensor b_
scale
s, Tensor
? b_zeros_or_none, Tensor? g_idx_or_non
e, "
"
Tensor? perm_or_none, Tensor workspace,
int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool has_zp, bool use_atomic_add, bool use_fp32_reduce, "
"bool is_zp_float) -> Tensor"
,
"bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor"
,
{
stride_tag
});
// conditionally compiled so impl registration is in source file
...
...
@@ -341,14 +340,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"ggml_moe_get_block_size"
,
&
ggml_moe_get_block_size
);
#ifndef USE_ROCM
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops
.
def
(
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
"SymInt size_k) -> Tensor"
,
{
stride_tag
});
// conditionally compiled so impl registration is in source file
// marlin_qqq_gemm for QQQ.
ops
.
def
(
"marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
...
...
tests/kernels/moe/test_moe.py
View file @
1d0c9d6b
...
...
@@ -11,19 +11,20 @@ from transformers import MixtralConfig
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
import
vllm.model_executor.layers.fused_moe
# noqa
from
tests.kernels.utils
import
(
opcheck
,
stack_and_dev
,
torch_moe
,
torch_moe_single
)
from
tests.kernels.utils
import
opcheck
,
stack_and_dev
,
torch_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.moe_torch_iterative
import
(
fused_moe
as
iterative_moe
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
marlin_quant_fp8_torch
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
awq_marlin_quantize
,
marlin_quantize
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
quantize_weights
)
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
ScalarType
,
scalar_types
NUM_EXPERTS
=
[
8
,
64
]
EP_SIZE
=
[
1
,
4
]
...
...
@@ -285,7 +286,7 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
atol
=
mixtral_moe_tol
[
dtype
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
3
3
,
123
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
12
3
,
666
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
256
,
2048
])
@
pytest
.
mark
.
parametrize
(
"e"
,
[
4
,
12
])
...
...
@@ -294,8 +295,10 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
-
1
,
32
,
128
])
@
pytest
.
mark
.
parametrize
(
"act_order"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
[
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"has_zp"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
scalar_types
.
uint4
,
scalar_types
.
uint8b128
,
scalar_types
.
uint4b8
,
scalar_types
.
float8_e4m3fn
])
@
pytest
.
mark
.
parametrize
(
"is_k_full"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Skip for rocm"
)
def
test_fused_marlin_moe
(
...
...
@@ -308,14 +311,22 @@ def test_fused_marlin_moe(
dtype
:
torch
.
dtype
,
group_size
:
int
,
act_order
:
bool
,
num_bits
:
int
,
has_zp
:
bool
,
quant_type
:
ScalarType
,
is_k_full
:
bool
,
):
current_platform
.
seed_everything
(
7
)
torch
.
cuda
.
manual_seed
(
0
)
has_zp
=
quant_type
in
[
scalar_types
.
uint4
,
scalar_types
.
uint8
]
if
quant_type
==
scalar_types
.
float8_e4m3fn
:
if
group_size
not
in
[
-
1
,
128
]:
return
if
act_order
:
return
# Filter act_order
if
act_order
:
if
quant_type
==
scalar_types
.
float8_e4m3fn
:
return
if
group_size
==
-
1
:
return
if
group_size
in
(
k
,
n
):
...
...
@@ -326,17 +337,9 @@ def test_fused_marlin_moe(
if
not
is_k_full
:
return
if
has_zp
:
# we don't build kernel for int8 with zero
if
num_bits
==
8
:
return
quant_type
=
scalar_types
.
uint4
if
num_bits
==
4
else
scalar_types
.
uint8
else
:
quant_type
=
scalar_types
.
uint4b8
\
if
num_bits
==
4
else
scalar_types
.
uint8b128
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
1
0
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
1
0
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
2
0
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
2
0
if
ep_size
>
1
:
local_e
=
e
//
ep_size
...
...
@@ -364,17 +367,23 @@ def test_fused_marlin_moe(
qweight1_l
.
append
(
qweight1
)
scales1_l
.
append
(
scales1
)
zeros1_l
.
append
(
zeros1
)
el
se
:
el
if
quant_type
!=
scalar_types
.
float8_e4m3fn
:
test_perm
=
torch
.
randperm
(
k
)
quant_res
=
marlin_quantize
(
w1
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
,
act_order
,
test_perm
)
w_ref1
,
qweight1
,
scales1
,
g_idx1
,
sort_indices1
,
_
=
quant_res
w_ref1
,
qweight1
,
scales1
,
g_idx1
,
sort_indices1
,
_
=
\
marlin_quantize
(
w1
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
,
act_order
,
test_perm
)
w_ref1_l
.
append
(
w_ref1
.
T
)
qweight1_l
.
append
(
qweight1
)
scales1_l
.
append
(
scales1
)
g_idx1_l
.
append
(
g_idx1
)
sort_indices1_l
.
append
(
sort_indices1
)
else
:
w_ref1
,
qweight1
,
scales1
=
marlin_quant_fp8_torch
(
w1
[
i
],
group_size
)
w_ref1_l
.
append
(
w_ref1
.
T
)
qweight1_l
.
append
(
qweight1
)
scales1_l
.
append
(
scales1
)
w_ref1
=
stack_and_dev
(
w_ref1_l
)
qweight1
=
stack_and_dev
(
qweight1_l
).
contiguous
()
...
...
@@ -399,17 +408,23 @@ def test_fused_marlin_moe(
qweight2_l
.
append
(
qweight2
)
scales2_l
.
append
(
scales2
)
zeros2_l
.
append
(
zeros2
)
el
se
:
el
if
quant_type
!=
scalar_types
.
float8_e4m3fn
:
test_perm
=
torch
.
randperm
(
n
)
quant_res
=
marlin_quantize
(
w2
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
,
act_order
,
test_perm
)
w_ref2
,
qweight2
,
scales2
,
g_idx2
,
sort_indices2
,
_
=
quant_res
w_ref2
,
qweight2
,
scales2
,
g_idx2
,
sort_indices2
,
_
=
\
marlin_quantize
(
w2
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
,
act_order
,
test_perm
)
w_ref2_l
.
append
(
w_ref2
.
T
)
qweight2_l
.
append
(
qweight2
)
scales2_l
.
append
(
scales2
)
g_idx2_l
.
append
(
g_idx2
)
sort_indices2_l
.
append
(
sort_indices2
)
else
:
w_ref2
,
qweight2
,
scales2
=
marlin_quant_fp8_torch
(
w2
[
i
],
group_size
)
w_ref2_l
.
append
(
w_ref2
.
T
)
qweight2_l
.
append
(
qweight2
)
scales2_l
.
append
(
scales2
)
w_ref2
=
stack_and_dev
(
w_ref2_l
)
qweight2
=
stack_and_dev
(
qweight2_l
).
contiguous
()
...
...
@@ -442,102 +457,10 @@ def test_fused_marlin_moe(
sort_indices2
=
sort_indices2
,
w1_zeros
=
zeros1
,
w2_zeros
=
zeros2
,
num_bits
=
num_bits
,
quant_type_id
=
quant_type
.
id
,
is_k_full
=
is_k_full
)
torch
.
testing
.
assert_close
(
marlin_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
@
pytest
.
mark
.
skip
(
"This test is here for the sake of debugging, "
"don't run it in automated tests."
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
33
,
123
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
256
,
2048
])
@
pytest
.
mark
.
parametrize
(
"e"
,
[
4
,
12
])
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
2
,
3
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
-
1
,
32
,
128
])
@
pytest
.
mark
.
parametrize
(
"act_order"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
[
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"has_zp"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"is_k_full"
,
[
True
,
False
])
def
test_single_marlin_moe_multiply
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
group_size
:
int
,
act_order
:
bool
,
num_bits
:
int
,
has_zp
:
bool
,
is_k_full
:
bool
):
# Filter act_order
if
act_order
:
if
group_size
==
-
1
:
return
if
group_size
in
(
k
,
n
):
return
if
has_zp
:
return
else
:
if
not
is_k_full
:
return
if
has_zp
:
quant_type
=
scalar_types
.
uint4
if
num_bits
==
4
else
scalar_types
.
uint8
else
:
quant_type
=
scalar_types
.
uint4b8
\
if
num_bits
==
4
else
scalar_types
.
uint8b128
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w
=
torch
.
randn
((
e
,
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w_ref_l
=
[]
qweight_l
=
[]
scales_l
=
[]
zeros_l
=
[]
g_idx_l
=
[]
sort_indices_l
=
[]
for
i
in
range
(
w
.
shape
[
0
]):
if
has_zp
:
w_ref
,
qweight
,
scales
,
zeros
=
awq_marlin_quantize
(
w
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
)
w_ref_l
.
append
(
w_ref
.
T
)
qweight_l
.
append
(
qweight
)
scales_l
.
append
(
scales
)
zeros_l
.
append
(
zeros
)
else
:
test_perm
=
torch
.
randperm
(
k
)
w_ref
,
qweight
,
scales
,
g_idx
,
sort_indices
,
_
=
marlin_quantize
(
w
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
,
act_order
,
test_perm
)
w_ref_l
.
append
(
w_ref
.
T
)
qweight_l
.
append
(
qweight
)
scales_l
.
append
(
scales
)
g_idx_l
.
append
(
g_idx
)
sort_indices_l
.
append
(
sort_indices
)
w_ref
=
stack_and_dev
(
w_ref_l
)
qweight
=
stack_and_dev
(
qweight_l
).
contiguous
()
scales
=
stack_and_dev
(
scales_l
)
g_idx
=
stack_and_dev
(
g_idx_l
)
if
g_idx_l
else
None
zeros
=
stack_and_dev
(
zeros_l
)
if
zeros_l
else
None
sort_indices
=
stack_and_dev
(
sort_indices_l
)
if
sort_indices_l
else
None
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
marlin_output
=
torch
.
ops
.
vllm
.
single_marlin_moe
(
a
,
qweight
,
scales
,
score
,
topk
,
renormalize
=
False
,
g_idx
=
g_idx
,
sort_indices
=
sort_indices
,
w_zeros
=
zeros
,
num_bits
=
num_bits
,
is_k_full
=
is_k_full
,
)
torch_output
=
torch_moe_single
(
a
,
w_ref
,
score
,
topk
)
torch
.
testing
.
assert_close
(
marlin_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
marlin_output
,
torch_output
,
atol
=
5e-2
,
rtol
=
0
)
def
test_moe_align_block_size_opcheck
():
...
...
tests/kernels/quantization/test_awq_marlin.py
deleted
100644 → 0
View file @
f62cad64
# SPDX-License-Identifier: Apache-2.0
"""Test AWQ with fused MoE Marlin kernels.
Run `pytest tests/kernels/test_awq_marlin.py`.
"""
import
pytest
import
torch
import
vllm.model_executor.layers.fused_moe
# noqa
from
tests.kernels.utils
import
(
compute_max_diff
,
stack_and_dev
,
torch_moe
,
torch_moe_single
)
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
awq_marlin_quantize
)
from
vllm.scalar_type
import
scalar_types
NUM_EXPERTS
=
[
8
,
64
]
TOP_KS
=
[
2
,
6
]
GROUP_SIZES
=
[
-
1
,
32
,
128
]
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
33
,
64
,
222
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
2048
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
1024
])
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
GROUP_SIZES
)
@
pytest
.
mark
.
skipif
(
not
(
ops
.
supports_moe_ops
and
hasattr
(
torch
.
ops
.
_moe_C
,
"marlin_gemm_moe"
)),
reason
=
"Marlin is not supported on this GPU type."
)
def
test_fused_marlin_moe_awq
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
group_size
:
int
,
):
torch
.
manual_seed
(
7
)
num_bits
=
4
quant_type
=
scalar_types
.
uint4
dtype
=
torch
.
float16
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w_ref1_l
=
[]
qweights1_l
=
[]
scales1_l
=
[]
zp1_l
=
[]
for
i
in
range
(
w1
.
shape
[
0
]):
w_ref1
,
qweight1
,
scales1
,
zp1
=
awq_marlin_quantize
(
w1
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
)
w_ref1_l
.
append
(
w_ref1
)
qweights1_l
.
append
(
qweight1
)
scales1_l
.
append
(
scales1
)
zp1_l
.
append
(
zp1
)
w_ref1
=
stack_and_dev
(
w_ref1_l
)
qweight1
=
stack_and_dev
(
qweights1_l
).
contiguous
()
scales1
=
stack_and_dev
(
scales1_l
)
zp1
=
stack_and_dev
(
zp1_l
)
w_ref2_l
=
[]
qweights2_l
=
[]
scales2_l
=
[]
zp2_l
=
[]
for
i
in
range
(
w2
.
shape
[
0
]):
w_ref2
,
qweight2
,
scales2
,
zp2
=
awq_marlin_quantize
(
w2
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
)
w_ref2_l
.
append
(
w_ref2
)
qweights2_l
.
append
(
qweight2
)
scales2_l
.
append
(
scales2
)
zp2_l
.
append
(
zp2
)
w_ref2
=
stack_and_dev
(
w_ref2_l
)
qweight2
=
stack_and_dev
(
qweights2_l
).
contiguous
()
scales2
=
stack_and_dev
(
scales2_l
)
zp2
=
stack_and_dev
(
zp2_l
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_topk
(
a
,
score
,
topk
,
False
)
marlin_output
=
torch
.
ops
.
vllm
.
fused_marlin_moe
(
a
,
qweight1
,
qweight2
,
scales1
,
scales2
,
score
,
topk_weights
,
topk_ids
,
w1_zeros
=
zp1
,
w2_zeros
=
zp2
,
num_bits
=
num_bits
,
)
torch_output
=
torch_moe
(
a
,
w_ref1
.
transpose
(
1
,
2
),
w_ref2
.
transpose
(
1
,
2
),
score
,
topk
,
None
)
assert
compute_max_diff
(
marlin_output
,
torch_output
)
<
4e-2
@
pytest
.
mark
.
skip
(
"This test is here for the sake of debugging, "
"don't run it in automated tests."
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
64
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
2048
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
1024
,
512
])
@
pytest
.
mark
.
parametrize
(
"e"
,
[
8
,
64
])
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
2
,
6
])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
-
1
,
32
,
64
,
128
])
def
test_single_marlin_moe_multiply_awq
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
group_size
:
int
,
):
torch
.
manual_seed
(
7
)
num_bits
=
4
quant_type
=
scalar_types
.
uint4
dtype
=
torch
.
float16
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w
=
torch
.
randn
((
e
,
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w_ref_l
=
[]
qweights_l
=
[]
scales_l
=
[]
zp_l
=
[]
for
i
in
range
(
w
.
shape
[
0
]):
w_ref
,
qweight
,
scales
,
zp
=
awq_marlin_quantize
(
w
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
)
w_ref_l
.
append
(
w_ref
)
qweights_l
.
append
(
qweight
)
scales_l
.
append
(
scales
)
zp_l
.
append
(
zp
)
w_ref
=
stack_and_dev
(
w_ref_l
)
qweight
=
stack_and_dev
(
qweights_l
).
contiguous
()
scales
=
stack_and_dev
(
scales_l
).
contiguous
()
zp
=
stack_and_dev
(
zp_l
).
contiguous
()
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
marlin_output
=
torch
.
ops
.
vllm
.
single_marlin_moe
(
a
,
qweight
,
scales
,
score
,
topk
,
renormalize
=
False
,
w_zeros
=
zp
,
num_bits
=
num_bits
)
torch_output
=
torch_moe_single
(
a
,
w_ref
.
transpose
(
1
,
2
),
score
,
topk
)
assert
compute_max_diff
(
marlin_output
,
torch_output
)
<
1e-2
tests/kernels/quantization/test_marlin_gemm.py
View file @
1d0c9d6b
...
...
@@ -18,9 +18,10 @@ from vllm.model_executor.layers.quantization.qqq import (
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
MARLIN_SUPPORTED_GROUP_SIZES
,
marlin_make_empty_g_idx
,
marlin_permute_scales
,
query_marlin_supported_quant_types
)
marlin_make_workspace_new
,
marlin_permute_scales
,
query_marlin_supported_quant_types
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
pack
_fp8_to
_int32
)
marlin_quant
_fp8_to
rch
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
MarlinWorkspace
,
awq_marlin_quantize
,
get_weight_perm
,
marlin_quantize
,
marlin_weights
)
...
...
@@ -73,7 +74,7 @@ def rand_data(shape, dtype=torch.float16):
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
query_marlin_supported_quant_types
(
False
))
query_marlin_supported_quant_types
(
False
,
False
))
@
pytest
.
mark
.
parametrize
(
"group_size"
,
MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
...
...
@@ -138,7 +139,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
query_marlin_supported_quant_types
(
Fals
e
))
query_marlin_supported_quant_types
(
Tru
e
))
@
pytest
.
mark
.
parametrize
(
"group_size"
,
MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
def
test_awq_marlin_repack
(
k_chunk
,
n_chunk
,
quant_type
,
group_size
,
...
...
@@ -220,38 +221,50 @@ def test_gptq_marlin_gemm(
if
group_size
==
size_k
:
return
if
size_k
%
group_size
!=
0
:
return
a_input
=
rand_data
((
size_m
,
size_k
))
b_weight
=
rand_data
((
size_k
,
size_n
))
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
_
=
marlin_quantize
(
b_weight
,
quant_type
,
group_size
,
act_order
)
if
quant_type
==
scalar_types
.
float8_e4m3fn
:
if
group_size
not
in
[
-
1
,
128
]:
return
if
act_order
:
return
w_ref
,
marlin_q_w
,
marlin_s
=
marlin_quant_fp8_torch
(
b_weight
.
T
,
group_size
)
g_idx
=
None
sort_indices
=
None
else
:
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
_
=
marlin_quantize
(
b_weight
,
quant_type
,
group_size
,
act_order
)
marlin_zp
=
marlin_make_empty_g_idx
(
marlin_s
.
device
)
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
workspace
=
marlin_make_workspace_new
(
w_ref
.
device
)
opcheck
(
torch
.
ops
.
_C
.
gptq_marlin_gemm
,
(
a_input
,
marlin_q_w
,
marlin_s
,
marlin_zp
,
g_idx
,
sort_indices
,
workspace
.
scratch
,
quant_type
.
id
,
a_input
.
shape
[
0
]
,
b_weight
.
shape
[
1
]
,
a_input
.
shape
[
1
],
is_k_full
,
False
,
use_atomic_add
,
use_fp32_reduce
,
False
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
)
opcheck
(
torch
.
ops
.
_C
.
gptq_marlin_gemm
,
(
a_input
,
None
,
marlin_q_w
,
marlin_s
,
marlin_zp
,
g_
id
x
,
sort_indices
,
workspace
,
quant_type
.
id
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
]
,
a_input
.
shape
[
1
],
is_k_full
,
use_atomic_add
,
use_fp32_reduce
,
False
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
)
output
=
ops
.
gptq_marlin_gemm
(
a_input
,
None
,
marlin_q_w
,
marlin_s
,
marlin_zp
,
g_idx
,
sort_indices
,
workspace
.
scratch
,
workspace
,
quant_type
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
=
is_k_full
,
has_zp
=
False
,
use_atomic_add
=
use_atomic_add
,
use_fp32_reduce
=
use_fp32_reduce
,
is_zp_float
=
False
,
...
...
@@ -326,80 +339,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
assert
max_diff
<
0.04
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
-
1
])
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
def
test_fp8_marlin_gemm
(
k_chunk
,
n_chunk
,
num_bits
,
group_size
,
mnk_factors
,
dtype
,
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
size_m
=
m_factor
size_k
=
k_chunk
*
k_factor
size_n
=
n_chunk
*
n_factor
a_input
=
rand_data
((
size_m
,
size_k
),
dtype
=
dtype
)
b_weight
=
rand_data
((
size_k
,
size_n
),
dtype
=
dtype
)
# WEIGHTS
fp8_weight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
b_weight
,
scale
=
None
)
# Repack weights to gptq format (packed int32 elements)
packed_gptq_qweight
=
pack_fp8_to_int32
(
fp8_weight
)
# Repack weights to marlin format
marlin_qweight
=
ops
.
gptq_marlin_repack
(
b_q_weight
=
packed_gptq_qweight
,
perm
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
"cuda"
),
size_k
=
size_k
,
size_n
=
size_n
,
num_bits
=
8
,
)
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
scales
=
weight_scale
.
repeat
(
1
,
size_n
).
to
(
a_input
.
dtype
).
to
(
"cuda"
)
# Permute scales
marlin_scales
=
marlin_permute_scales
(
s
=
scales
,
size_k
=
size_k
,
size_n
=
size_n
,
group_size
=-
1
)
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
opcheck
(
torch
.
ops
.
_C
.
fp8_marlin_gemm
,
(
a_input
,
marlin_qweight
,
marlin_scales
,
workspace
.
scratch
,
num_bits
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
]))
output
=
ops
.
fp8_marlin_gemm
(
a
=
a_input
,
b_q_weight
=
marlin_qweight
,
b_scales
=
marlin_scales
,
workspace
=
workspace
.
scratch
,
num_bits
=
num_bits
,
size_m
=
a_input
.
shape
[
0
],
size_n
=
b_weight
.
shape
[
1
],
size_k
=
a_input
.
shape
[
1
],
)
output_ref
=
torch
.
matmul
(
a_input
,
b_weight
)
torch
.
cuda
.
synchronize
()
max_diff
=
compute_max_diff
(
output
,
output_ref
)
assert
max_diff
<
0.04
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
...
...
@@ -432,25 +371,23 @@ def test_awq_marlin_gemm(
g_idx
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
marlin_q_w
.
device
)
sort_indices
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
marlin_q_w
.
device
)
is_k_full
=
True
has_zp
=
True
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
workspace
=
marlin_make_workspace_new
(
a_input
.
device
)
output
=
ops
.
gptq_marlin_gemm
(
a_input
,
None
,
marlin_q_w
,
marlin_s
,
marlin_zp
,
g_idx
,
sort_indices
,
workspace
.
scratch
,
workspace
,
quant_type
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
=
is_k_full
,
has_zp
=
has_zp
,
use_fp32_reduce
=
use_fp32_reduce
,
is_zp_float
=
False
,
)
...
...
@@ -508,23 +445,22 @@ def test_hqq_marlin_gemm(
g_idx
=
marlin_make_empty_g_idx
(
dev
)
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
dev
)
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
workspace
=
marlin_make_workspace_new
(
b_weight
.
device
)
output
=
ops
.
gptq_marlin_gemm
(
a_input
,
None
,
marlin_w_q
,
marlin_s
,
marlin_zp
,
g_idx
,
g_idx_sort_indices
,
workspace
.
scratch
,
workspace
,
quant_type
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
0
],
a_input
.
shape
[
1
],
is_k_full
=
True
,
has_zp
=
True
,
use_fp32_reduce
=
use_fp32_reduce
,
is_zp_float
=
True
,
)
...
...
@@ -621,23 +557,22 @@ def test_marlin_gemm_subset_input():
b_weight
,
quant_type
,
group_size
,
False
)
marlin_zp
=
marlin_make_empty_g_idx
(
marlin_s
.
device
)
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
workspace
=
marlin_make_workspace_new
(
a_input
.
device
)
output
=
ops
.
gptq_marlin_gemm
(
a_input
,
None
,
marlin_q_w
,
marlin_s
,
marlin_zp
,
g_idx
,
sort_indices
,
workspace
.
scratch
,
workspace
,
quant_type
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
=
True
,
has_zp
=
False
,
use_atomic_add
=
False
,
use_fp32_reduce
=
True
,
is_zp_float
=
False
,
...
...
vllm/_custom_ops.py
View file @
1d0c9d6b
...
...
@@ -325,18 +325,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@
register_fake
(
"_C::gptq_marlin_gemm"
)
def
_gptq_marlin_gemm_fake
(
a
:
torch
.
Tensor
,
c
:
Optional
[
torch
.
Tensor
],
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_zeros
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
b_zeros
:
Optional
[
torch
.
Tensor
]
,
g_idx
:
Optional
[
torch
.
Tensor
]
,
perm
:
Optional
[
torch
.
Tensor
]
,
workspace
:
torch
.
Tensor
,
b_q_type
:
ScalarType
,
b_q_type
_id
:
int
,
size_m
:
torch
.
SymInt
,
size_n
:
torch
.
SymInt
,
size_k
:
torch
.
SymInt
,
is_k_full
:
bool
,
has_zp
:
bool
=
False
,
is_k_full
:
bool
=
True
,
use_atomic_add
:
bool
=
False
,
use_fp32_reduce
:
bool
=
False
,
is_zp_float
:
bool
=
False
)
->
torch
.
Tensor
:
...
...
@@ -407,14 +407,6 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype
=
codebooks
.
dtype
,
device
=
codebooks
.
device
)
@
register_fake
(
"_C::fp8_marlin_gemm"
)
def
_fp8_marlin_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
size_m
:
torch
.
SymInt
,
size_n
:
torch
.
SymInt
,
size_k
:
torch
.
SymInt
)
->
torch
.
Tensor
:
return
torch
.
empty
((
size_m
,
size_n
),
dtype
=
a
.
dtype
,
device
=
a
.
device
)
@
register_fake
(
"_C::machete_mm"
)
def
machete_mm_fake
(
a
:
torch
.
Tensor
,
...
...
@@ -815,35 +807,26 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
def
gptq_marlin_gemm
(
a
:
torch
.
Tensor
,
c
:
Optional
[
torch
.
Tensor
],
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_zeros
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
b_zeros
:
Optional
[
torch
.
Tensor
]
,
g_idx
:
Optional
[
torch
.
Tensor
]
,
perm
:
Optional
[
torch
.
Tensor
]
,
workspace
:
torch
.
Tensor
,
b_q_type
:
ScalarType
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
,
is_k_full
:
bool
,
has_zp
:
bool
=
False
,
is_k_full
:
bool
=
True
,
use_atomic_add
:
bool
=
False
,
use_fp32_reduce
:
bool
=
False
,
is_zp_float
:
bool
=
False
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
gptq_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
b_zeros
,
return
torch
.
ops
.
_C
.
gptq_marlin_gemm
(
a
,
c
,
b_q_weight
,
b_scales
,
b_zeros
,
g_idx
,
perm
,
workspace
,
b_q_type
.
id
,
size_m
,
size_n
,
size_k
,
is_k_full
,
has_zp
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
)
# fp8 marlin
def
fp8_marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
fp8_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
workspace
,
num_bits
,
size_m
,
size_n
,
size_k
)
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
)
# machete
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
1d0c9d6b
...
...
@@ -7,163 +7,13 @@ import torch
import
vllm._custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
moe_align_block_size
,
try_get_optimal_moe_config
)
from
vllm.scalar_type
import
scalar_types
moe_align_block_size
,
try_get_optimal_moe_config
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
marlin_make_workspace_new
,
maybe_warn_marlin_atomic_add
)
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.utils
import
direct_register_custom_op
def
get_scalar_type
(
num_bits
:
int
,
has_zp
:
bool
):
if
has_zp
:
return
scalar_types
.
uint4
if
num_bits
==
4
else
scalar_types
.
uint8
else
:
return
scalar_types
.
uint4b8
if
num_bits
==
4
else
scalar_types
.
uint8b128
def
single_marlin_moe
(
hidden_states
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
w_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
workspace
:
Optional
[
torch
.
Tensor
]
=
None
,
num_bits
:
int
=
8
,
is_k_full
:
bool
=
True
,
)
->
torch
.
Tensor
:
"""
This function computes the multiplication of hidden_states with expert
weights used in Marlin MoE, using weights w and top-k gating mechanism.
Its purpose is testing and debugging the fused MoE kernel.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the Marlin Mul.
- w (torch.Tensor): The set of expert weights.
- scales (torch.Tensor): The quantization scales.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx (Optional[torch.Tensor]): Optional act_order indices.
- sort_indices (Optional[torch.Tensor]): Optional act_order input
permutation.
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
- num_bits (bool): The number of bits in expert weights quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
"Number of tokens mismatch"
)
assert
hidden_states
.
shape
[
1
]
==
w
.
shape
[
1
]
*
16
,
"Hidden size mismatch"
assert
gating_output
.
shape
[
1
]
==
w
.
shape
[
0
],
"Number of experts mismatch"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w
.
is_contiguous
(),
"Expert weights must be contiguous"
assert
hidden_states
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
assert
num_bits
in
[
4
,
8
]
M
,
K
=
hidden_states
.
shape
E
=
w
.
shape
[
0
]
N
=
w
.
shape
[
2
]
//
(
num_bits
//
2
)
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_topk
(
hidden_states
,
gating_output
,
topk
,
renormalize
)
# This might not be an optimal config for a single MMM
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
w
.
shape
,
w
.
shape
,
topk_ids
.
shape
[
1
],
None
,
is_marlin
=
True
)
config
=
get_config_func
(
M
)
block_size_m
=
config
[
'BLOCK_SIZE_M'
]
if
global_num_experts
==
-
1
:
global_num_experts
=
E
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
\
moe_align_block_size
(
topk_ids
,
block_size_m
,
E
,
expert_map
)
if
workspace
is
None
:
max_workspace_size
=
(
max
(
2
*
N
,
K
)
//
64
)
*
\
(
sorted_token_ids
.
size
(
0
)
//
block_size_m
)
device
=
hidden_states
.
device
sms
=
torch
.
cuda
.
get_device_properties
(
device
).
multi_processor_count
max_workspace_size
=
min
(
max_workspace_size
,
sms
)
workspace
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
device
=
device
,
requires_grad
=
False
)
scalar_type
=
get_scalar_type
(
num_bits
,
w_zeros
is
not
None
)
intermediate_cache
=
torch
.
empty
(
(
M
*
topk_ids
.
shape
[
1
],
N
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
ops
.
moe_wna16_marlin_gemm
(
hidden_states
,
intermediate_cache
,
w
,
scales
,
w_zeros
,
g_idx
,
sort_indices
,
workspace
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
topk_weights
,
moe_block_size
=
block_size_m
,
top_k
=
topk
,
mul_topk_weights
=
False
,
is_ep
=
expert_map
is
not
None
,
b_q_type
=
scalar_type
,
size_m
=
M
,
size_n
=
N
,
size_k
=
K
,
is_k_full
=
is_k_full
,
use_atomic_add
=
False
,
use_fp32_reduce
=
True
,
is_zp_float
=
False
)
intermediate_cache
=
intermediate_cache
.
view
(
-
1
,
topk
,
N
)
return
torch
.
sum
(
intermediate_cache
.
view
(
*
intermediate_cache
.
shape
),
dim
=
1
)
def
single_marlin_moe_fake
(
hidden_states
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
w_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
workspace
:
Optional
[
torch
.
Tensor
]
=
None
,
num_bits
:
int
=
8
,
is_k_full
:
bool
=
True
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
direct_register_custom_op
(
op_name
=
"single_marlin_moe"
,
op_func
=
single_marlin_moe
,
mutates_args
=
[],
fake_impl
=
single_marlin_moe_fake
,
)
def
fused_marlin_moe
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
...
@@ -172,6 +22,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
gating_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
quant_type_id
:
int
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx1
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -181,7 +32,6 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
w1_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
workspace
:
Optional
[
torch
.
Tensor
]
=
None
,
num_bits
:
int
=
8
,
is_k_full
:
bool
=
True
,
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
"""
...
...
@@ -211,6 +61,15 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
quant_type
=
ScalarType
.
from_id
(
quant_type_id
)
assert
quant_type
in
[
scalar_types
.
uint4
,
scalar_types
.
uint8b128
,
scalar_types
.
uint4b8
,
scalar_types
.
float8_e4m3fn
]
int4_scalar_types
=
[
scalar_types
.
uint4
,
scalar_types
.
uint4b8
]
num_bits
=
4
if
quant_type
in
int4_scalar_types
else
8
# Check constraints.
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
"Number of tokens mismatch"
...
...
@@ -248,18 +107,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
expert_map
)
if
workspace
is
None
:
max_workspace_size
=
(
max
(
2
*
N
,
K
)
//
64
)
*
\
(
sorted_token_ids
.
size
(
0
)
//
block_size_m
)
device
=
hidden_states
.
device
sms
=
torch
.
cuda
.
get_device_properties
(
device
).
multi_processor_count
max_workspace_size
=
min
(
max_workspace_size
,
sms
*
4
)
workspace
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
device
=
device
,
requires_grad
=
False
)
scalar_type1
=
get_scalar_type
(
num_bits
,
w1_zeros
is
not
None
)
scalar_type2
=
get_scalar_type
(
num_bits
,
w2_zeros
is
not
None
)
workspace
=
marlin_make_workspace_new
(
hidden_states
.
device
,
4
)
intermediate_cache2
=
torch
.
empty
(
(
M
*
topk_ids
.
shape
[
1
],
N
),
...
...
@@ -276,6 +124,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
intermediate_cache3
=
intermediate_cache13
[:
M
*
topk_ids
.
shape
[
1
]
*
K
]
intermediate_cache3
=
intermediate_cache3
.
view
(
-
1
,
K
)
maybe_warn_marlin_atomic_add
(
hidden_states
.
device
,
hidden_states
.
dtype
)
use_atomic_add
=
hidden_states
.
dtype
==
torch
.
half
or
\
torch
.
cuda
.
get_device_capability
(
hidden_states
.
device
)[
0
]
>=
9
...
...
@@ -296,7 +145,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
top_k
=
topk
,
mul_topk_weights
=
False
,
is_ep
=
expert_map
is
not
None
,
b_q_type
=
scalar
_type
1
,
b_q_type
=
quant
_type
,
size_m
=
M
,
size_n
=
2
*
N
,
size_k
=
K
,
...
...
@@ -328,7 +177,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
top_k
=
1
,
mul_topk_weights
=
True
,
is_ep
=
expert_map
is
not
None
,
b_q_type
=
scalar
_type
2
,
b_q_type
=
quant
_type
,
size_m
=
M
*
topk
,
size_n
=
K
,
size_k
=
N
,
...
...
@@ -351,6 +200,7 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
gating_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
quant_type_id
:
int
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx1
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -360,7 +210,6 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
w1_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
workspace
:
Optional
[
torch
.
Tensor
]
=
None
,
num_bits
:
int
=
8
,
is_k_full
:
bool
=
True
,
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
1d0c9d6b
...
...
@@ -22,9 +22,10 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_awq_marlin_linear
,
awq_to_marlin_zero_points
,
check_marlin_supported
,
check_marlin_supports_layer
,
check_moe_marlin_supports_layer
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_moe_permute_scales
,
marlin_permute_scales
,
moe_awq_to_marlin_zero_points
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
marlin_make_empty_g_idx
,
marlin_make_workspace_new
,
marlin_moe_permute_scales
,
marlin_permute_scales
,
moe_awq_to_marlin_zero_points
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
PackedvLLMParameter
)
...
...
@@ -267,8 +268,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
requires_grad
=
False
)
# Allocate marlin workspace
layer
.
workspace
=
marlin_make_workspace
(
layer
.
output_size_per_partition
,
device
)
layer
.
workspace
=
marlin_make_workspace_new
(
device
)
# Repack weights from AWQ format to marlin format.
marlin_qweight
=
ops
.
awq_marlin_repack
(
...
...
@@ -322,6 +322,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
def
__init__
(
self
,
quant_config
:
AWQMarlinConfig
):
self
.
quant_config
=
quant_config
if
self
.
quant_config
.
weight_bits
!=
4
:
raise
ValueError
(
"AWQMoEMethod only supports 4bit now."
)
self
.
quant_type
=
scalar_types
.
uint4
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
...
...
@@ -396,11 +399,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
set_weight_attrs
(
w2_qzeros
,
extra_weight_attrs
)
device
=
layer
.
w13_qweight
.
device
sms
=
torch
.
cuda
.
get_device_properties
(
device
).
multi_processor_count
layer
.
workspace
=
torch
.
zeros
((
sms
*
4
,
),
dtype
=
torch
.
int
,
device
=
device
,
requires_grad
=
False
)
layer
.
workspace
=
marlin_make_workspace_new
(
device
,
4
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
num_experts
=
layer
.
w13_qweight
.
shape
[
0
]
...
...
@@ -511,10 +510,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
router_logits
,
topk_weights
,
topk_ids
,
quant_type_id
=
self
.
quant_type
.
id
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_zeros
=
layer
.
w13_qzeros
,
w2_zeros
=
layer
.
w2_qzeros
,
workspace
=
layer
.
workspace
,
num_bits
=
self
.
quant_config
.
weight_bits
,
)
workspace
=
layer
.
workspace
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
View file @
1d0c9d6b
...
...
@@ -55,7 +55,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
# required by torch.compile to be torch.nn.Parameter
layer
.
input_scale
=
torch
.
nn
.
Parameter
(
layer
.
input_scale
.
data
,
requires_grad
=
False
)
prepare_fp8_layer_for_marlin
(
layer
,
strategy
=
"channel"
)
prepare_fp8_layer_for_marlin
(
layer
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size
:
int
,
output_partition_sizes
:
List
[
int
],
...
...
@@ -68,6 +68,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
# WEIGHT
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment