Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d74e5f37
Unverified
Commit
d74e5f37
authored
May 11, 2025
by
Jinzhen Lin
Committed by
GitHub
May 10, 2025
Browse files
[Kernel] fp4 marlin kernel (#17687)
Signed-off-by:
Jinzhen Lin
<
linjinzhen@hotmail.com
>
parent
ca66a167
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1191 additions
and
328 deletions
+1191
-328
csrc/core/scalar_type.hpp
csrc/core/scalar_type.hpp
+3
-0
csrc/moe/marlin_moe_wna16/generate_kernels.py
csrc/moe/marlin_moe_wna16/generate_kernels.py
+11
-2
csrc/moe/marlin_moe_wna16/kernel.h
csrc/moe/marlin_moe_wna16/kernel.h
+12
-11
csrc/moe/marlin_moe_wna16/marlin_template.h
csrc/moe/marlin_moe_wna16/marlin_template.h
+99
-40
csrc/moe/marlin_moe_wna16/ops.cu
csrc/moe/marlin_moe_wna16/ops.cu
+66
-18
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+2
-1
csrc/quantization/gptq_marlin/dequant.h
csrc/quantization/gptq_marlin/dequant.h
+286
-70
csrc/quantization/gptq_marlin/generate_kernels.py
csrc/quantization/gptq_marlin/generate_kernels.py
+11
-2
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+65
-16
csrc/quantization/gptq_marlin/kernel.h
csrc/quantization/gptq_marlin/kernel.h
+8
-7
csrc/quantization/gptq_marlin/marlin_template.h
csrc/quantization/gptq_marlin/marlin_template.h
+85
-36
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+2
-2
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+103
-32
tests/kernels/quantization/test_marlin_gemm.py
tests/kernels/quantization/test_marlin_gemm.py
+41
-73
vllm/_custom_ops.py
vllm/_custom_ops.py
+12
-8
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+11
-3
vllm/model_executor/layers/quantization/hqq_marlin.py
vllm/model_executor/layers/quantization/hqq_marlin.py
+3
-1
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+80
-4
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+14
-2
vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
...el_executor/layers/quantization/utils/marlin_utils_fp4.py
+277
-0
No files found.
csrc/core/scalar_type.hpp
View file @
d74e5f37
...
...
@@ -315,6 +315,8 @@ static inline constexpr auto kS8 = ScalarType::int_(8);
static
inline
constexpr
auto
kU8
=
ScalarType
::
uint
(
8
);
static
inline
constexpr
auto
kU8B128
=
ScalarType
::
uint
(
8
,
128
);
static
inline
constexpr
auto
kFE2M1f
=
ScalarType
::
float_
(
2
,
1
,
true
,
ScalarType
::
NAN_NONE
);
static
inline
constexpr
auto
kFE3M2f
=
ScalarType
::
float_
(
3
,
2
,
true
,
ScalarType
::
NAN_NONE
);
static
inline
constexpr
auto
kFE4M3fn
=
...
...
@@ -332,6 +334,7 @@ static inline constexpr auto kInt8 = kS8;
static
inline
constexpr
auto
kUint8
=
kU8
;
static
inline
constexpr
auto
kUint8b128
=
kU8B128
;
static
inline
constexpr
auto
kFloat4_e2m1f
=
kFE2M1f
;
static
inline
constexpr
auto
kFloat6_e3m2f
=
kFE3M2f
;
static
inline
constexpr
auto
kFloat8_e4m3fn
=
kFE4M3fn
;
static
inline
constexpr
auto
kFloat8_e5m2
=
kFE5M2
;
...
...
csrc/moe/marlin_moe_wna16/generate_kernels.py
View file @
d74e5f37
...
...
@@ -31,7 +31,10 @@ TEMPLATE = ("template __global__ void Marlin<"
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES
=
[
"vllm::kU4"
,
"vllm::kU4B8"
,
"vllm::kU8B128"
,
"vllm::kFE4M3fn"
]
SCALAR_TYPES
=
[
"vllm::kU4"
,
"vllm::kU4B8"
,
"vllm::kU8B128"
,
"vllm::kFE4M3fn"
,
"vllm::kFE2M1f"
]
THREAD_CONFIGS
=
[(
128
,
128
,
256
),
(
64
,
256
,
256
),
(
64
,
128
,
128
)]
THREAD_M_BLOCKS
=
[
0.5
,
1
,
2
,
3
,
4
]
...
...
@@ -39,7 +42,7 @@ THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS
=
[
0
,
-
1
,
2
,
4
,
8
]
GROUP_BLOCKS
=
[
0
,
-
1
,
1
,
2
,
4
,
8
]
DTYPES
=
[
"fp16"
,
"bf16"
]
...
...
@@ -72,6 +75,12 @@ def generate_new_kernels():
# for fp8
if
scalar_type
==
"vllm::kFE4M3fn"
and
group_blocks
not
in
[
-
1
,
8
]:
continue
# nvfp4 only supports group_size == 16
if
scalar_type
==
"vllm::kFE2M1f"
and
group_blocks
not
in
[
1
,
2
]:
continue
# other quantization methods don't support group_size = 16
if
scalar_type
!=
"vllm::kFE2M1f"
and
group_blocks
==
1
:
continue
k_blocks
=
thread_configs
[
0
]
//
16
n_blocks
=
thread_configs
[
1
]
//
16
...
...
csrc/moe/marlin_moe_wna16/kernel.h
View file @
d74e5f37
...
...
@@ -7,17 +7,18 @@
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem
namespace
MARLIN_NAMESPACE_NAME
{
...
...
csrc/moe/marlin_moe_wna16/marlin_template.h
View file @
d74e5f37
...
...
@@ -301,9 +301,11 @@ __global__ void Marlin(
int4
*
__restrict__
C_tmp
,
// fp32 tmp output buffer (for reduce)
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
const
uint16_t
*
__restrict__
scale2_ptr
,
// fp16 global scale (for nvfp4
// only)
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
const
int32_t
*
__restrict__
sorted_token_ids_ptr
,
// moe sorted_ids
const
int32_t
*
__restrict__
expert_ids_ptr
,
// moe expert ids
const
int32_t
*
__restrict__
num_tokens_past_padded_ptr
,
// moe num tokens
...
...
@@ -341,6 +343,16 @@ __global__ void Marlin(
extern
__shared__
int4
sh
[];
static
constexpr
auto
w_type
=
vllm
::
ScalarType
::
from_id
(
w_type_id
);
constexpr
bool
has_zp
=
w_type
==
vllm
::
kU4
||
w_type
==
vllm
::
kU8
;
constexpr
bool
is_int_type
=
w_type
==
vllm
::
kU4
||
w_type
==
vllm
::
kU8
||
w_type
==
vllm
::
kU4B8
||
w_type
==
vllm
::
kU8B128
;
// see comments of dequant.h for more details
constexpr
bool
dequant_skip_flop
=
!
is_int_type
||
has_zp
&&
!
is_zp_float
&&
!
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
||
has_zp
&&
!
is_zp_float
&&
!
(
w_type
==
vllm
::
kU8
);
scalar_t2
global_scale
;
constexpr
bool
has_act_order
=
group_blocks
==
0
;
constexpr
int
pack_factor
=
32
/
w_type
.
size_bits
();
...
...
@@ -348,7 +360,8 @@ __global__ void Marlin(
constexpr
int
moe_block_size
=
m_block_size_8
?
8
:
(
16
*
thread_m_blocks
);
const
int
group_size
=
(
!
has_act_order
&&
group_blocks
==
-
1
)
?
prob_k
:
prob_k
/
num_groups
;
const
int
scales_expert_stride
=
prob_n
*
prob_k
/
group_size
/
8
;
const
int
scales_expert_stride
=
prob_n
*
prob_k
/
group_size
/
(
w_type
==
vllm
::
kFE2M1f
?
16
:
8
);
const
int
zp_expert_stride
=
is_zp_float
?
prob_n
*
prob_k
/
group_size
/
8
:
prob_n
*
prob_k
/
group_size
/
(
pack_factor
*
4
);
...
...
@@ -460,9 +473,16 @@ __global__ void Marlin(
if
(
mul_topk_weights
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
sh_block_topk_weights
[
tid4
*
4
+
i
]
=
Dtype
::
num2num2
(
Dtype
::
float2num
(
topk_weights_ptr
[
sh_block_sorted_ids
[
tid4
*
4
+
i
]]));
if
constexpr
(
w_type
==
vllm
::
kFE2M1f
)
{
sh_block_topk_weights
[
tid4
*
4
+
i
]
=
__hmul2
(
global_scale
,
Dtype
::
num2num2
(
Dtype
::
float2num
(
topk_weights_ptr
[
sh_block_sorted_ids
[
tid4
*
4
+
i
]])));
}
else
{
sh_block_topk_weights
[
tid4
*
4
+
i
]
=
Dtype
::
num2num2
(
Dtype
::
float2num
(
topk_weights_ptr
[
sh_block_sorted_ids
[
tid4
*
4
+
i
]]));
}
}
}
}
...
...
@@ -493,6 +513,11 @@ __global__ void Marlin(
expert_id
=
expert_ids_ptr
[
block_id
];
}
if
constexpr
(
w_type
==
vllm
::
kFE2M1f
)
{
uint16_t
val
=
scale2_ptr
[
expert_id
];
global_scale
=
Dtype
::
num2num2
(
*
reinterpret_cast
<
scalar_t
*>
(
&
val
));
}
B_expert_off
=
expert_id
*
prob_n
*
prob_k
/
(
pack_factor
*
4
);
scales_ptr
+=
(
expert_id
-
old_expert_id
)
*
scales_expert_stride
;
if
constexpr
(
has_zp
)
{
...
...
@@ -606,7 +631,7 @@ __global__ void Marlin(
constexpr
int
s_sh_stride
=
16
*
thread_n_blocks
/
8
;
constexpr
int
s_tb_groups
=
!
has_act_order
&&
group_blocks
!=
-
1
&&
group_blocks
<
thread_k_blocks
?
thread_k_blocks
/
group_blocks
?
thread_k_blocks
/
group_blocks
/
(
w_type
==
vllm
::
kFE2M1f
?
2
:
1
)
:
1
;
constexpr
int
s_sh_stage
=
s_tb_groups
*
s_sh_stride
;
int
s_gl_rd_delta
=
s_gl_stride
;
...
...
@@ -664,7 +689,8 @@ __global__ void Marlin(
if
constexpr
(
group_blocks
==
-
1
)
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
else
{
s_gl_rd
=
s_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
+
s_gl_rd
=
s_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
/
(
w_type
==
vllm
::
kFE2M1f
?
2
:
1
)
+
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
...
...
@@ -688,10 +714,20 @@ __global__ void Marlin(
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int
s_sh_rd
;
if
constexpr
(
group_blocks
!=
-
1
)
if
constexpr
(
group_blocks
!=
-
1
&&
w_type
==
vllm
::
kFE2M1f
)
{
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
4
;
else
if
constexpr
(
group_blocks
==
-
1
&&
(
m_block_size_8
||
has_zp
))
s_sh_rd
=
s_sh_rd
*
2
+
warp_row
%
2
;
}
else
if
constexpr
(
group_blocks
!=
-
1
)
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
4
;
else
if
constexpr
(
group_blocks
==
-
1
&&
(
m_block_size_8
||
(
has_zp
&&
!
dequant_skip_flop
)))
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
8
;
else
...
...
@@ -801,7 +837,7 @@ __global__ void Marlin(
sh_first_group_id
=
first_group_id
;
sh_num_groups
=
last_group_id
-
first_group_id
+
1
;
if
(
sh_num_groups
<
act_s_max_num_groups
)
{
if
(
sh_num_groups
>
act_s_max_num_groups
)
{
sh_num_groups
=
act_s_max_num_groups
;
}
...
...
@@ -1021,12 +1057,19 @@ __global__ void Marlin(
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
int
cur_group_id
=
k_blocks
/
group_blocks
;
int
cur_group_id
=
k_blocks
/
(
group_blocks
*
(
w_type
==
vllm
::
kFE2M1f
?
2
:
1
));
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
+
cur_group_id
*
s_sh_stride
];
if
constexpr
(
w_type_id
!=
vllm
::
kFE2M1f
.
id
())
{
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
+
cur_group_id
*
s_sh_stride
];
}
else
{
reinterpret_cast
<
int2
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
reinterpret_cast
<
int2
*>
(
sh_s_stage
)[
s_sh_rd
+
cur_group_id
*
(
2
*
s_sh_stride
)];
}
}
}
...
...
@@ -1199,22 +1242,7 @@ __global__ void Marlin(
};
auto
dequant_data
=
[
&
](
int
q
,
scalar_t2
*
frag_b_ptr
)
{
if
constexpr
(
has_zp
&&
is_zp_float
||
!
has_zp
)
{
dequant
<
scalar_t2
,
w_type_id
>
(
q
,
frag_b_ptr
);
}
else
{
static_assert
(
has_zp
&&
!
is_zp_float
);
static_assert
(
w_type_id
==
vllm
::
kU4
.
id
()
||
w_type_id
==
vllm
::
kU8
.
id
());
// If (has_zp && !is_zp_float),
// we use not-zp version `dequant` function
// to improve numerical accuracy.
// Since both weight and zero point are dequanted using this logic,
// the final dequanted weight would be correct.
if
constexpr
(
w_type_id
==
vllm
::
kU4
.
id
())
{
dequant
<
scalar_t2
,
vllm
::
kU4B8
.
id
()
>
(
q
,
frag_b_ptr
);
}
else
if
constexpr
(
w_type_id
==
vllm
::
kU8
.
id
())
{
dequant
<
scalar_t2
,
vllm
::
kU8B128
.
id
()
>
(
q
,
frag_b_ptr
);
}
}
dequant
<
scalar_t2
,
w_type_id
,
dequant_skip_flop
>
(
q
,
frag_b_ptr
);
};
// Execute the actual tensor core matmul of a sub-tile.
...
...
@@ -1244,13 +1272,23 @@ __global__ void Marlin(
dequant_data
(
zp_quant_1
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_zp
)
+
2
);
}
}
if
constexpr
(
has_zp
&&
is_zp_float
)
{
if
constexpr
(
!
dequant_skip_flop
&&
has_zp
&&
is_zp_float
)
{
if
(
is_new_zp
)
{
reinterpret_cast
<
int4
*>
(
&
frag_zp
)[
0
]
=
reinterpret_cast
<
int4
*>
(
&
frag_zpf
[
k2
])[
0
];
}
}
if
constexpr
(
w_type
==
vllm
::
kFE2M1f
)
{
int
s_quant_0
=
reinterpret_cast
<
int
*>
(
frag_s
[
k2
])[
0
];
int
s_quant_1
=
reinterpret_cast
<
int
*>
(
frag_s
[
k2
])[
1
];
dequant_fp8_scales
<
scalar_t2
>
(
s_quant_0
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_s
[
k2
]));
dequant_fp8_scales
<
scalar_t2
>
(
s_quant_1
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_s
[
k2
])
+
2
);
}
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
...
...
@@ -1259,7 +1297,10 @@ __global__ void Marlin(
FragB
frag_b1
;
int
b_quant_0
,
b_quant_1
;
if
constexpr
(
w_type
.
size_bits
()
==
4
)
{
if
constexpr
(
w_type_id
==
vllm
::
kFE2M1f
.
id
())
{
b_quant_1
=
frag_b_quant
[
k2
][
0
][
j
];
b_quant_0
=
b_quant_1
<<
8
;
}
else
if
constexpr
(
w_type
.
size_bits
()
==
4
)
{
b_quant_0
=
frag_b_quant
[
k2
][
0
][
j
];
b_quant_1
=
b_quant_0
>>
8
;
}
else
{
...
...
@@ -1272,6 +1313,11 @@ __global__ void Marlin(
dequant_data
(
b_quant_0
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_b0
));
dequant_data
(
b_quant_1
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_b1
));
if
constexpr
(
dequant_skip_flop
&&
has_zp
&&
!
is_zp_float
)
{
sub_zp
<
scalar_t
>
(
frag_b0
,
frag_zp
[
j
],
0
);
sub_zp
<
scalar_t
>
(
frag_b1
,
frag_zp
[
j
],
1
);
}
// Apply scale to frag_b0
if
constexpr
(
has_act_order
)
{
static_assert
(
group_blocks
!=
-
1
);
...
...
@@ -1279,7 +1325,8 @@ __global__ void Marlin(
act_frag_s
[
k2
][
2
][
j
],
act_frag_s
[
k2
][
3
][
j
],
0
);
scale4
<
scalar_t
>
(
frag_b1
,
act_frag_s
[
k2
][
0
][
j
],
act_frag_s
[
k2
][
1
][
j
],
act_frag_s
[
k2
][
2
][
j
],
act_frag_s
[
k2
][
3
][
j
],
1
);
}
else
if
constexpr
(
has_zp
&&
!
is_zp_float
&&
group_blocks
==
-
1
)
{
}
else
if
constexpr
(
!
dequant_skip_flop
&&
has_zp
&&
!
is_zp_float
&&
group_blocks
==
-
1
)
{
int
idx
=
(
threadIdx
.
x
/
4
)
%
2
;
scalar_t2
s2
=
Dtype
::
nums2num2
(
reinterpret_cast
<
scalar_t
*>
(
&
frag_s
[
j
/
2
][
j
%
2
*
2
+
0
])[
idx
],
...
...
@@ -1287,7 +1334,7 @@ __global__ void Marlin(
if
(
is_new_zp
)
frag_zp
[
j
]
=
__hmul2
(
frag_zp
[
j
],
s2
);
scale_and_sub
<
scalar_t
>
(
frag_b0
,
s2
.
x
,
frag_zp
[
j
].
x
);
scale_and_sub
<
scalar_t
>
(
frag_b1
,
s2
.
y
,
frag_zp
[
j
].
y
);
}
else
if
constexpr
(
has_zp
&&
group_blocks
!=
-
1
)
{
}
else
if
constexpr
(
!
dequant_skip_flop
&&
has_zp
&&
group_blocks
!=
-
1
)
{
if
(
is_new_zp
)
frag_zp
[
j
]
=
__hmul2
(
frag_zp
[
j
],
*
reinterpret_cast
<
scalar_t2
*>
(
&
frag_s
[
k2
][
j
]));
...
...
@@ -1554,10 +1601,17 @@ __global__ void Marlin(
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
w_type
.
size_bits
()
==
4
&&
!
has_zp
)
{
w_type
.
size_bits
()
==
4
&&
(
has_zp
&&
dequant_skip_flop
||
!
has_zp
))
{
res
=
__hmul2
(
res
,
s
[
0
]);
}
if
constexpr
(
w_type
==
vllm
::
kFE2M1f
)
{
if
(
!
mul_topk_weights
)
{
res
=
__hmul2
(
res
,
global_scale
);
}
}
if
constexpr
(
m_block_size_8
)
{
((
scalar_t
*
)
sh_red
)[
idx
]
=
res
.
x
;
((
scalar_t
*
)
sh_red
)[
idx
+
8
*
c_sh_stride
]
=
res
.
y
;
...
...
@@ -1648,7 +1702,9 @@ __global__ void Marlin(
if
constexpr
(
has_zp
&&
!
is_zp_float
&&
group_blocks
==
-
1
)
{
if
(
i
==
0
)
{
fetch_col_zp_to_shared
();
fetch_col_scale_to_shared
();
if
constexpr
(
!
dequant_skip_flop
)
{
fetch_col_scale_to_shared
();
}
}
}
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
,
i
);
...
...
@@ -1737,7 +1793,8 @@ __global__ void Marlin(
bool
last
=
slice_idx
==
slice_count
-
1
;
// For per-column scales, we only fetch them here in the final step before
// write-out
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
!
has_zp
)
{
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
(
has_zp
&&
dequant_skip_flop
||
!
has_zp
))
{
if
(
w_type
.
size_bits
()
==
8
||
(
last
||
use_atomic_add
))
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
...
...
@@ -1747,7 +1804,8 @@ __global__ void Marlin(
}
thread_block_reduce
();
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
!
has_zp
)
{
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
(
has_zp
&&
dequant_skip_flop
||
!
has_zp
))
{
if
(
w_type
.
size_bits
()
==
8
||
(
last
||
use_atomic_add
))
{
cp_async_wait
<
0
>
();
__syncthreads
();
...
...
@@ -1771,7 +1829,8 @@ __global__ void Marlin(
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
w_type
.
size_bits
()
==
8
&&
!
has_zp
)
{
w_type
.
size_bits
()
==
8
&&
(
has_zp
&&
dequant_skip_flop
||
!
has_zp
))
{
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
...
...
csrc/moe/marlin_moe_wna16/ops.cu
View file @
d74e5f37
...
...
@@ -291,6 +291,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order case (group_blocks == 0)
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
...
...
@@ -338,6 +339,21 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF(W_TYPE) \
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
FP4_GET_IF_M234(W_TYPE, 8, 4, 128)
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
...
...
@@ -394,6 +410,8 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
BIGGROUP_GET_IF
(
vllm
::
kFE4M3fn
)
FP4_GET_IF
(
vllm
::
kFE2M1f
)
ACT_GET_IF
(
vllm
::
kU4B8
)
ACT_GET_IF
(
vllm
::
kU8B128
)
...
...
@@ -465,7 +483,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
template
<
typename
scalar_t
>
void
marlin_mm
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
C_tmp
,
void
*
s
,
void
*
zp
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
void
*
s2
,
void
*
zp
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
void
*
sorted_token_ids
,
void
*
expert_ids
,
void
*
num_tokens_past_padded
,
void
*
topk_weights
,
int
moe_block_size
,
int
top_k
,
bool
mul_topk_weights
,
bool
is_ep
,
...
...
@@ -479,14 +497,16 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
bool
m_block_size_8
=
moe_block_size
==
8
;
if
(
has_zp
)
{
TORCH_CHECK
(
q_type
==
vllm
::
kU4
,
"q_type must be u4 when has_zp = True. Got = "
,
q_type
.
str
());
TORCH_CHECK
(
q_type
==
vllm
::
kU4
||
q_type
==
vllm
::
kU8
,
"q_type must be u4 or u8 when has_zp = True. Got = "
,
q_type
.
str
());
}
else
{
TORCH_CHECK
(
q_type
==
vllm
::
kU4B8
||
q_type
==
vllm
::
kU8B128
||
q_type
==
vllm
::
kFE4M3fn
,
"q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
"False. Got = "
,
q_type
.
str
());
TORCH_CHECK
(
q_type
==
vllm
::
kU4B8
||
q_type
==
vllm
::
kU8B128
||
q_type
==
vllm
::
kFE4M3fn
||
q_type
==
vllm
::
kFE2M1f
,
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
"has_zp = False. Got = "
,
q_type
.
str
());
}
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
...
...
@@ -519,6 +539,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
int4
*
C_ptr
=
(
int4
*
)
C
;
int4
*
C_tmp_ptr
=
(
int4
*
)
C_tmp
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
const
uint16_t
*
s2_ptr
=
(
const
uint16_t
*
)
s2
;
const
int4
*
zp_ptr
=
(
const
int4
*
)
zp
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
;
const
int
*
perm_ptr
=
(
const
int
*
)
perm
;
...
...
@@ -627,7 +648,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel
<<<
blocks
,
num_threads
,
max_shared_mem
,
stream
>>>
(
A_ptr
,
B_ptr
,
C_ptr
,
C_tmp_ptr
,
s_ptr
,
zp_ptr
,
g_idx_ptr
,
A_ptr
,
B_ptr
,
C_ptr
,
C_tmp_ptr
,
s_ptr
,
s2_ptr
,
zp_ptr
,
g_idx_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
num_tokens_past_padded_ptr
,
topk_weights_ptr
,
top_k
,
mul_topk_weights
,
is_ep
,
num_groups
,
prob_m
,
prob_n
,
prob_k
,
locks
,
use_atomic_add
,
use_fp32_reduce
,
max_shared_mem
);
...
...
@@ -639,6 +660,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
torch
::
Tensor
moe_wna16_marlin_gemm
(
torch
::
Tensor
&
a
,
std
::
optional
<
torch
::
Tensor
>
const
&
c_or_none
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
global_scale_or_none
,
std
::
optional
<
torch
::
Tensor
>
const
&
b_zeros_or_none
,
std
::
optional
<
torch
::
Tensor
>
const
&
g_idx_or_none
,
std
::
optional
<
torch
::
Tensor
>
const
&
perm_or_none
,
torch
::
Tensor
&
workspace
,
...
...
@@ -790,6 +812,17 @@ torch::Tensor moe_wna16_marlin_gemm(
}
}
torch
::
Tensor
global_scale
;
if
(
global_scale_or_none
.
has_value
())
{
global_scale
=
global_scale_or_none
.
value
();
TORCH_CHECK
(
b_q_type
==
vllm
::
kFE2M1f
,
"global_scale can only be used for float4_e2m1f."
);
}
else
{
global_scale
=
torch
::
empty
({
0
},
options
);
TORCH_CHECK
(
!
(
b_q_type
==
vllm
::
kFE2M1f
),
"the global_scale parameter must be passed for float4_e2m1f."
);
}
torch
::
Tensor
b_zeros
;
if
(
b_zeros_or_none
.
has_value
())
{
b_zeros
=
b_zeros_or_none
.
value
();
...
...
@@ -802,13 +835,14 @@ torch::Tensor moe_wna16_marlin_gemm(
if
(
has_zp
)
{
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4
,
"b_q_type must be u4 when has_zp = True. Got = "
,
b_q_type
.
str
());
b_q_type
==
vllm
::
kU4
||
b_q_type
==
vllm
::
kU8
,
"b_q_type must be u4
or u8
when has_zp = True. Got = "
,
b_q_type
.
str
());
}
else
{
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4B8
||
b_q_type
==
vllm
::
kU8B128
||
b_q_type
==
vllm
::
kFE4M3fn
,
"b_q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
"False. Got = "
,
b_q_type
==
vllm
::
kFE4M3fn
||
b_q_type
==
vllm
::
kFE2M1f
,
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or "
"float4_e2m1f when "
"has_zp = False. Got = "
,
b_q_type
.
str
());
}
...
...
@@ -854,9 +888,16 @@ torch::Tensor moe_wna16_marlin_gemm(
int
dev
=
a
.
get_device
();
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
void
*
scales_ptr
;
if
(
b_q_type
==
vllm
::
kFE2M1f
)
{
scales_ptr
=
b_scales
.
data_ptr
<
at
::
Float8_e4m3fn
>
();
}
else
{
scales_ptr
=
b_scales
.
data_ptr
<
at
::
Half
>
();
}
MARLIN_NAMESPACE_NAME
::
marlin_mm
<
half
>
(
a
.
data_ptr
<
at
::
Half
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
Half
>
(),
c_tmp
.
data_ptr
<
float
>
(),
b
_scale
s
.
data_ptr
<
at
::
Half
>
(),
c_tmp
.
data_ptr
<
float
>
(),
scales_ptr
,
global
_scale
.
data_ptr
<
at
::
Half
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
Half
>
(),
sorted_token_ids
.
data_ptr
(),
expert_ids
.
data_ptr
(),
num_tokens_past_padded
.
data_ptr
(),
...
...
@@ -866,11 +907,18 @@ torch::Tensor moe_wna16_marlin_gemm(
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
);
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
void
*
scales_ptr
;
if
(
b_q_type
==
vllm
::
kFE2M1f
)
{
scales_ptr
=
b_scales
.
data_ptr
<
at
::
Float8_e4m3fn
>
();
}
else
{
scales_ptr
=
b_scales
.
data_ptr
<
at
::
BFloat16
>
();
}
MARLIN_NAMESPACE_NAME
::
marlin_mm
<
nv_bfloat16
>
(
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
c_tmp
.
data_ptr
<
float
>
(),
b
_scale
s
.
data_ptr
<
at
::
BFloat16
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
c_tmp
.
data_ptr
<
float
>
(),
scales_ptr
,
global
_scale
.
data_ptr
<
at
::
BFloat16
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
sorted_token_ids
.
data_ptr
(),
expert_ids
.
data_ptr
(),
num_tokens_past_padded
.
data_ptr
(),
topk_weights
.
data_ptr
(),
moe_block_size
,
top_k
,
mul_topk_weights
,
is_ep
,
size_m
,
size_n
,
size_k
,
...
...
csrc/moe/torch_bindings.cpp
View file @
d74e5f37
...
...
@@ -44,7 +44,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m
.
def
(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? "
"b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
...
...
csrc/quantization/gptq_marlin/dequant.h
View file @
d74e5f37
/*
Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16)
The process of fast dequantization can be summarized as a combination
of bitwise operations and floating-point computations:
weight =>(bit_op / bitwise operations)=>
f16_value =>(flop / floating-point computation)=>
dequantized_weight
Since the dequantized weights typically require subtracting the zero point and
applying a scale factor, the floating-point computation step can be fused with
the zero-point subtraction and scaling operations.
The following are the parts that need to be modified for the fused operation
of zero-point subtraction and scaling.
## INT4 => FP16/BF16 or INT8 => FP16
The floating-point computation is `__hsub2`
If there are zero points:
flop(bit_op(weight)) - flop(bit_op(zp))
= sub(bit_op(weight), bias) - sub(bit_op(zp), bias)
= bit_op(weight) - bit_op(zp)
so we don't need additional modification.
If there are float zero points:
flop(bit_op(weight)) - fzp
= sub(bit_op(weight), bias) - fzp
= bit_op(weight) - (fzp + bias)
where `fzp + bias` can be computed at weight loading. However, this
may cause accuracy issues, so it should not be used in most cases.
If there are no zero points:
scale(flop(bit_op(weight)))
= scale(sub(bit_op(weight), bias))
= scale(bit_op(weight)) - scale(bias)
= fma(bit_op(weight), scale_factor, -scale(bias))
where `-scale(bias)` can be cached. However, this may cause accuracy issues,
so it should not be used in most cases.
## INT8 => BF16
INT8 => BF16 is a special case: it uses byte_perm instead of a flop,
so the byte_perm step cannot be fused with scaling.
## FP4/FP8 => FP16/BF16
scale(flop(bit_op(weight)))
= scale(mul(bit_op(weight), multiplier))
= mul(bit_op(weight), scale_factor * multiplier)
where `scale_factor * multiplier` can be computed at weight loading.
*/
#include "marlin_dtypes.cuh"
...
...
@@ -27,7 +91,8 @@ __device__ inline uint32_t prmt(uint32_t a) {
return
res
;
}
template
<
typename
scalar_t2
,
vllm
::
ScalarTypeId
w_type_id
>
template
<
typename
scalar_t2
,
vllm
::
ScalarTypeId
w_type_id
,
bool
skip_flop
=
false
>
__device__
inline
void
dequant
(
int
q
,
scalar_t2
*
frag_b
);
//
...
...
@@ -40,7 +105,22 @@ __device__ inline void dequant(int q, scalar_t2* frag_b);
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template
<
>
__device__
inline
void
dequant
<
half2
,
vllm
::
kU4B8
.
id
()
>
(
int
q
,
half2
*
frag_b
)
{
__device__
inline
void
dequant
<
half2
,
vllm
::
kU4B8
.
id
(),
true
>
(
int
q
,
half2
*
frag_b
)
{
const
int
MASK
=
0x000f000f
;
const
int
EX
=
0x64006400
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
frag_b
[
0
]
=
*
reinterpret_cast
<
half2
*>
(
&
lo
);
frag_b
[
1
]
=
*
reinterpret_cast
<
half2
*>
(
&
hi
);
}
template
<
>
__device__
inline
void
dequant
<
half2
,
vllm
::
kU4B8
.
id
(),
false
>
(
int
q
,
half2
*
frag_b
)
{
const
int
LO
=
0x000f000f
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
...
...
@@ -62,7 +142,14 @@ __device__ inline void dequant<half2, vllm::kU4B8.id()>(int q, half2* frag_b) {
}
template
<
>
__device__
inline
void
dequant
<
half2
,
vllm
::
kU4
.
id
()
>
(
int
q
,
half2
*
frag_b
)
{
__device__
inline
void
dequant
<
half2
,
vllm
::
kU4
.
id
(),
true
>
(
int
q
,
half2
*
frag_b
)
{
dequant
<
half2
,
vllm
::
kU4B8
.
id
(),
true
>
(
q
,
frag_b
);
}
template
<
>
__device__
inline
void
dequant
<
half2
,
vllm
::
kU4
.
id
(),
false
>
(
int
q
,
half2
*
frag_b
)
{
const
int
LO
=
0x000f000f
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
...
...
@@ -84,7 +171,7 @@ __device__ inline void dequant<half2, vllm::kU4.id()>(int q, half2* frag_b) {
}
template
<
>
__device__
inline
void
dequant
<
nv_bfloat162
,
vllm
::
kU4B8
.
id
()
>
(
__device__
inline
void
dequant
<
nv_bfloat162
,
vllm
::
kU4B8
.
id
()
,
true
>
(
int
q
,
nv_bfloat162
*
frag_b
)
{
static
constexpr
uint32_t
MASK
=
0x000f000f
;
static
constexpr
uint32_t
EX
=
0x43004300
;
...
...
@@ -96,39 +183,36 @@ __device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id()>(
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
// clang-format on
static
constexpr
uint32_t
MUL
=
0x3F803F80
;
static
constexpr
uint32_t
ADD
=
0xC308C308
;
frag_b
[
0
]
=
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
lo
);
frag_b
[
1
]
=
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
hi
);
}
template
<
>
__device__
inline
void
dequant
<
nv_bfloat162
,
vllm
::
kU4B8
.
id
(),
false
>
(
int
q
,
nv_bfloat162
*
frag_b
)
{
dequant
<
nv_bfloat162
,
vllm
::
kU4B8
.
id
(),
true
>
(
q
,
frag_b
);
frag_b
[
0
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
lo
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
hi
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
static
constexpr
uint32_t
SUB
=
0x43084308
;
frag_b
[
0
]
=
__hsub2
(
frag_b
[
0
],
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
SUB
));
frag_b
[
1
]
=
__hsub2
(
frag_b
[
1
],
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
SUB
));
}
template
<
>
__device__
inline
void
dequant
<
nv_bfloat162
,
vllm
::
kU4
.
id
()
>
(
__device__
inline
void
dequant
<
nv_bfloat162
,
vllm
::
kU4
.
id
()
,
true
>
(
int
q
,
nv_bfloat162
*
frag_b
)
{
static
constexpr
uint32_t
MASK
=
0x000f000f
;
static
constexpr
uint32_t
EX
=
0x43004300
;
dequant
<
nv_bfloat162
,
vllm
::
kU4B8
.
id
(),
true
>
(
q
,
frag_b
)
;
}
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
// clang-format on
template
<
>
__device__
inline
void
dequant
<
nv_bfloat162
,
vllm
::
kU4
.
id
(),
false
>
(
int
q
,
nv_bfloat162
*
frag_b
)
{
dequant
<
nv_bfloat162
,
vllm
::
kU4
.
id
(),
true
>
(
q
,
frag_b
);
static
constexpr
uint32_t
MUL
=
0x3F803F80
;
static
constexpr
uint32_t
ADD
=
0xC300C300
;
static
constexpr
uint32_t
SUB
=
0x43004300
;
frag_b
[
0
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
lo
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
hi
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
frag_b
[
0
]
=
__hsub2
(
frag_b
[
0
],
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
SUB
));
frag_b
[
1
]
=
__hsub2
(
frag_b
[
1
],
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
SUB
));
}
//
...
...
@@ -140,8 +224,8 @@ __device__ inline void dequant<nv_bfloat162, vllm::kU4.id()>(
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template
<
>
__device__
inline
void
dequant
<
half2
,
vllm
::
kU8B128
.
id
()
>
(
int
q
,
half2
*
frag_b
)
{
__device__
inline
void
dequant
<
half2
,
vllm
::
kU8B128
.
id
()
,
true
>
(
int
q
,
half2
*
frag_b
)
{
static
constexpr
uint32_t
mask_for_elt_01
=
0x5250
;
static
constexpr
uint32_t
mask_for_elt_23
=
0x5351
;
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
...
...
@@ -149,33 +233,42 @@ __device__ inline void dequant<half2, vllm::kU8B128.id()>(int q,
uint32_t
lo
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_01
>
(
q
);
uint32_t
hi
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_23
>
(
q
);
static
constexpr
uint32_t
I8s_TO_F16s_MAGIC_NUM
=
0x64806480
;
frag_b
[
0
]
=
*
reinterpret_cast
<
half2
*>
(
&
lo
);
frag_b
[
1
]
=
*
reinterpret_cast
<
half2
*>
(
&
hi
);
}
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
template
<
>
__device__
inline
void
dequant
<
half2
,
vllm
::
kU8B128
.
id
(),
false
>
(
int
q
,
half2
*
frag_b
)
{
dequant
<
half2
,
vllm
::
kU8B128
.
id
(),
true
>
(
q
,
frag_b
);
static
constexpr
uint32_t
I8s_TO_F16s_MAGIC_NUM
=
0x64806480
;
frag_b
[
0
]
=
__hsub2
(
frag_b
[
0
],
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
frag_b
[
1
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
)
,
frag_b
[
1
]
=
__hsub2
(
frag_b
[
1
]
,
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
}
template
<
>
__device__
inline
void
dequant
<
half2
,
vllm
::
kU8
.
id
()
>
(
int
q
,
half2
*
frag_b
)
{
static
constexpr
uint32_t
mask_for_elt_01
=
0x5250
;
static
constexpr
uint32_t
mask_for_elt_23
=
0x5351
;
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
__device__
inline
void
dequant
<
half2
,
vllm
::
kU8
.
id
()
,
true
>
(
int
q
,
half2
*
frag_b
)
{
dequant
<
half2
,
vllm
::
kU8B128
.
id
(),
true
>
(
q
,
frag_b
)
;
}
uint32_t
lo
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_01
>
(
q
);
uint32_t
hi
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_23
>
(
q
);
template
<
>
__device__
inline
void
dequant
<
half2
,
vllm
::
kU8
.
id
(),
false
>
(
int
q
,
half2
*
frag_b
)
{
dequant
<
half2
,
vllm
::
kU8
.
id
(),
true
>
(
q
,
frag_b
);
static
constexpr
uint32_t
I8s_TO_F16s_MAGIC_NUM
=
0x64006400
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
frag_b
[
0
]
=
__hsub2
(
frag_b
[
0
],
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
frag_b
[
1
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
)
,
frag_b
[
1
]
=
__hsub2
(
frag_b
[
1
]
,
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
}
template
<
>
__device__
inline
void
dequant
<
nv_bfloat162
,
vllm
::
kU8B128
.
id
()
>
(
__device__
inline
void
dequant
<
nv_bfloat162
,
vllm
::
kU8B128
.
id
()
,
false
>
(
int
q
,
nv_bfloat162
*
frag_b
)
{
float
fp32_intermediates
[
4
];
uint32_t
*
fp32_intermediates_casted
=
...
...
@@ -200,7 +293,7 @@ __device__ inline void dequant<nv_bfloat162, vllm::kU8B128.id()>(
}
template
<
>
__device__
inline
void
dequant
<
nv_bfloat162
,
vllm
::
kU8
.
id
()
>
(
__device__
inline
void
dequant
<
nv_bfloat162
,
vllm
::
kU8
.
id
()
,
false
>
(
int
q
,
nv_bfloat162
*
frag_b
)
{
float
fp32_intermediates
[
4
];
uint32_t
*
fp32_intermediates_casted
=
...
...
@@ -225,22 +318,30 @@ __device__ inline void dequant<nv_bfloat162, vllm::kU8.id()>(
}
template
<
>
__device__
inline
void
dequant
<
half2
,
vllm
::
kFE4M3fn
.
id
()
>
(
int
q
,
half2
*
frag_b
)
{
__device__
inline
void
dequant
<
half2
,
vllm
::
kFE4M3fn
.
id
()
,
true
>
(
int
q
,
half2
*
frag_b
)
{
// Constants for FP8 (E4M3) and FP16 formats
constexpr
int
FP8_EXPONENT
=
4
,
FP8_MANTISSA
=
3
,
FP16_EXPONENT
=
5
;
constexpr
int
FP8_EXPONENT
=
4
,
FP16_EXPONENT
=
5
;
constexpr
int
RIGHT_SHIFT
=
FP16_EXPONENT
-
FP8_EXPONENT
;
// Calculate MASK for extracting mantissa and exponent
constexpr
int
MASK1
=
0x80000000
;
constexpr
int
MASK2
=
MASK1
>>
(
FP8_EXPONENT
+
FP8_MANTISSA
);
constexpr
int
MASK3
=
MASK2
&
0x7fffffff
;
constexpr
int
MASK
=
MASK3
|
(
MASK3
>>
16
);
// Final MASK value: 0x7F007F00
constexpr
int
MASK
=
0x7F007F00
;
// Extract and shift FP8 values to FP16 format
int
Out1
=
(
q
&
0x80008000
)
|
((
q
&
MASK
)
>>
RIGHT_SHIFT
);
int
Out2
=
((
q
<<
8
)
&
0x80008000
)
|
(((
q
<<
8
)
&
MASK
)
>>
RIGHT_SHIFT
);
q
<<=
8
;
int
Out2
=
(
q
&
0x80008000
)
|
((
q
&
MASK
)
>>
RIGHT_SHIFT
);
// Note: reverse indexing is intentional because weights are permuted
frag_b
[
1
]
=
*
reinterpret_cast
<
const
half2
*>
(
&
Out1
);
frag_b
[
0
]
=
*
reinterpret_cast
<
const
half2
*>
(
&
Out2
);
}
template
<
>
__device__
inline
void
dequant
<
half2
,
vllm
::
kFE4M3fn
.
id
(),
false
>
(
int
q
,
half2
*
frag_b
)
{
dequant
<
half2
,
vllm
::
kFE4M3fn
.
id
(),
true
>
(
q
,
frag_b
);
// Constants for FP8 (E4M3) and FP16 formats
constexpr
int
FP8_EXPONENT
=
4
,
FP16_EXPONENT
=
5
;
// Construct and apply exponent bias
constexpr
int
BIAS_OFFSET
=
...
...
@@ -248,28 +349,36 @@ __device__ inline void dequant<half2, vllm::kFE4M3fn.id()>(int q,
const
half2
bias_reg
=
__float2half2_rn
(
float
(
1
<<
BIAS_OFFSET
));
// Convert to half2 and apply bias
// Note: reverse indexing is intentional because weights are permuted
frag_b
[
1
]
=
__hmul2
(
*
reinterpret_cast
<
const
half2
*>
(
&
Out1
),
bias_reg
);
frag_b
[
0
]
=
__hmul2
(
*
reinterpret_cast
<
const
half2
*>
(
&
Out2
),
bias_reg
);
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
bias_reg
);
frag_b
[
0
]
=
__hmul2
(
frag_b
[
0
],
bias_reg
);
}
template
<
>
__device__
inline
void
dequant
<
nv_bfloat162
,
vllm
::
kFE4M3fn
.
id
()
>
(
__device__
inline
void
dequant
<
nv_bfloat162
,
vllm
::
kFE4M3fn
.
id
()
,
true
>
(
int
q
,
nv_bfloat162
*
frag_b
)
{
// Constants for FP8 (E4M3) and BF16 formats
constexpr
int
FP8_EXPONENT
=
4
,
FP8_MANTISSA
=
3
,
BF16_EXPONENT
=
8
;
constexpr
int
FP8_EXPONENT
=
4
,
BF16_EXPONENT
=
8
;
constexpr
int
RIGHT_SHIFT
=
BF16_EXPONENT
-
FP8_EXPONENT
;
// Calculate MASK for extracting mantissa and exponent
constexpr
int
MASK1
=
0x80000000
;
constexpr
int
MASK2
=
MASK1
>>
(
FP8_EXPONENT
+
FP8_MANTISSA
);
constexpr
int
MASK3
=
MASK2
&
0x7fffffff
;
constexpr
int
MASK
=
MASK3
|
(
MASK3
>>
16
);
// Final MASK value: 0x7F007F00
constexpr
int
MASK
=
0x7F007F00
;
// Extract and shift FP8 values to BF16 format
int
Out1
=
(
q
&
0x80008000
)
|
((
q
&
MASK
)
>>
RIGHT_SHIFT
);
int
Out2
=
((
q
<<
8
)
&
0x80008000
)
|
(((
q
<<
8
)
&
MASK
)
>>
RIGHT_SHIFT
);
q
<<=
8
;
int
Out2
=
(
q
&
0x80008000
)
|
((
q
&
MASK
)
>>
RIGHT_SHIFT
);
// Note: reverse indexing is intentional because weights are permuted
frag_b
[
1
]
=
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
Out1
);
frag_b
[
0
]
=
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
Out2
);
}
template
<
>
__device__
inline
void
dequant
<
nv_bfloat162
,
vllm
::
kFE4M3fn
.
id
(),
false
>
(
int
q
,
nv_bfloat162
*
frag_b
)
{
dequant
<
nv_bfloat162
,
vllm
::
kFE4M3fn
.
id
(),
true
>
(
q
,
frag_b
);
// Constants for FP8 (E4M3) and BF16 formats
constexpr
int
FP8_EXPONENT
=
4
,
BF16_EXPONENT
=
8
;
// Construct and apply exponent bias
constexpr
int
BIAS_OFFSET
=
...
...
@@ -281,9 +390,116 @@ __device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id()>(
__float2bfloat162_rn
(
*
reinterpret_cast
<
const
float
*>
(
&
BIAS
));
// Convert to bfloat162 and apply bias
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
bias_reg
);
frag_b
[
0
]
=
__hmul2
(
frag_b
[
0
],
bias_reg
);
}
template
<
>
__device__
inline
void
dequant
<
half2
,
vllm
::
kFE2M1f
.
id
(),
true
>
(
int
q
,
half2
*
frag_b
)
{
// Constants for FP4 (E2M1) and FP16 formats
constexpr
int
FP4_EXPONENT
=
2
,
FP16_EXPONENT
=
5
;
constexpr
int
RIGHT_SHIFT
=
FP16_EXPONENT
-
FP4_EXPONENT
;
constexpr
int
MASK
=
0x70007000
;
// Extract and shift FP4 values to FP16 format
int
Out1
=
(
q
&
0x80008000
)
|
((
q
&
MASK
)
>>
RIGHT_SHIFT
);
q
<<=
4
;
int
Out2
=
(
q
&
0x80008000
)
|
((
q
&
MASK
)
>>
RIGHT_SHIFT
);
// Note: reverse indexing is intentional because weights are permuted
frag_b
[
1
]
=
*
reinterpret_cast
<
const
half2
*>
(
&
Out1
);
frag_b
[
0
]
=
*
reinterpret_cast
<
const
half2
*>
(
&
Out2
);
}
template
<
>
__device__
inline
void
dequant
<
half2
,
vllm
::
kFE2M1f
.
id
(),
false
>
(
int
q
,
half2
*
frag_b
)
{
dequant
<
half2
,
vllm
::
kFE2M1f
.
id
(),
true
>
(
q
,
frag_b
);
// Constants for FP4 (E2M1) and FP16 formats
constexpr
int
FP4_EXPONENT
=
2
,
FP16_EXPONENT
=
5
;
// Construct and apply exponent bias
constexpr
int
BIAS_OFFSET
=
(
1
<<
(
FP16_EXPONENT
-
1
))
-
(
1
<<
(
FP4_EXPONENT
-
1
));
const
half2
bias_reg
=
__float2half2_rn
(
float
(
1
<<
BIAS_OFFSET
));
// Convert to half2 and apply bias
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
bias_reg
);
frag_b
[
0
]
=
__hmul2
(
frag_b
[
0
],
bias_reg
);
}
template
<
>
__device__
inline
void
dequant
<
nv_bfloat162
,
vllm
::
kFE2M1f
.
id
(),
true
>
(
int
q
,
nv_bfloat162
*
frag_b
)
{
// Constants for FP4 (E2M1) and FP16 formats
constexpr
int
FP4_EXPONENT
=
2
,
BF16_EXPONENT
=
8
;
constexpr
int
RIGHT_SHIFT
=
BF16_EXPONENT
-
FP4_EXPONENT
;
constexpr
int
MASK
=
0x70007000
;
// Extract and shift FP4 values to FP16 format
int
Out1
=
(
q
&
0x80008000
)
|
((
q
&
MASK
)
>>
RIGHT_SHIFT
);
q
<<=
4
;
int
Out2
=
(
q
&
0x80008000
)
|
((
q
&
MASK
)
>>
RIGHT_SHIFT
);
// Note: reverse indexing is intentional because weights are permuted
frag_b
[
1
]
=
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
Out1
);
frag_b
[
0
]
=
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
Out2
);
}
template
<
>
__device__
inline
void
dequant
<
nv_bfloat162
,
vllm
::
kFE2M1f
.
id
(),
false
>
(
int
q
,
nv_bfloat162
*
frag_b
)
{
dequant
<
nv_bfloat162
,
vllm
::
kFE2M1f
.
id
(),
true
>
(
q
,
frag_b
);
// Constants for FP4 (E2M1) and BF16 formats
constexpr
int
FP4_EXPONENT
=
2
,
BF16_EXPONENT
=
8
;
// Construct and apply exponent bias
constexpr
int
BIAS_OFFSET
=
(
1
<<
(
BF16_EXPONENT
-
1
))
-
(
1
<<
(
FP4_EXPONENT
-
1
));
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
// position
constexpr
uint32_t
BIAS
=
(
BIAS_OFFSET
+
127
)
<<
23
;
const
nv_bfloat162
bias_reg
=
__float2bfloat162_rn
(
*
reinterpret_cast
<
const
float
*>
(
&
BIAS
));
// Convert to half2 and apply bias
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
bias_reg
);
frag_b
[
0
]
=
__hmul2
(
frag_b
[
0
],
bias_reg
);
}
template
<
typename
scalar_t2
>
__device__
inline
void
dequant_fp8_scales
(
int
q
,
scalar_t2
*
frag_b
);
template
<
>
__device__
inline
void
dequant_fp8_scales
<
half2
>
(
int
q
,
half2
*
frag_b
)
{
int
Out1
=
(
q
&
0xFF00FF00
)
>>
1
;
;
q
<<=
8
;
int
Out2
=
(
q
&
0xFF00FF00
)
>>
1
;
// Note: reverse indexing is intentional because weights are permuted
frag_b
[
1
]
=
*
reinterpret_cast
<
const
half2
*>
(
&
Out1
);
frag_b
[
0
]
=
*
reinterpret_cast
<
const
half2
*>
(
&
Out2
);
};
template
<
>
__device__
inline
void
dequant_fp8_scales
<
nv_bfloat162
>
(
int
q
,
nv_bfloat162
*
frag_b
)
{
constexpr
int
FP8_EXPONENT
=
4
,
BF16_EXPONENT
=
8
;
constexpr
int
RIGHT_SHIFT
=
BF16_EXPONENT
-
FP8_EXPONENT
;
constexpr
int
MASK
=
0x7F007F00
;
// Extract and shift FP8 values to BF16 format
int
Out1
=
((
q
&
0x80008000
)
>>
1
)
|
((
q
&
MASK
)
>>
RIGHT_SHIFT
);
q
<<=
8
;
int
Out2
=
((
q
&
0x80008000
)
>>
1
)
|
((
q
&
MASK
)
>>
RIGHT_SHIFT
);
// Note: reverse indexing is intentional because weights are permuted
frag_b
[
1
]
=
__hmul2
(
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
Out1
)
,
bias_reg
)
;
frag_b
[
0
]
=
__hmul2
(
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
Out2
)
,
bias_reg
)
;
frag_b
[
1
]
=
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
Out1
);
frag_b
[
0
]
=
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
Out2
);
}
#endif
...
...
csrc/quantization/gptq_marlin/generate_kernels.py
View file @
d74e5f37
...
...
@@ -31,7 +31,10 @@ TEMPLATE = ("template __global__ void Marlin<"
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES
=
[
"vllm::kU4"
,
"vllm::kU4B8"
,
"vllm::kU8B128"
,
"vllm::kFE4M3fn"
]
SCALAR_TYPES
=
[
"vllm::kU4"
,
"vllm::kU4B8"
,
"vllm::kU8B128"
,
"vllm::kFE4M3fn"
,
"vllm::kFE2M1f"
]
THREAD_CONFIGS
=
[(
128
,
128
,
256
),
(
64
,
256
,
256
),
(
64
,
128
,
128
),
(
128
,
64
,
128
)]
...
...
@@ -40,7 +43,7 @@ THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS
=
[
0
,
-
1
,
2
,
4
,
8
]
GROUP_BLOCKS
=
[
0
,
1
,
-
1
,
2
,
4
,
8
]
DTYPES
=
[
"fp16"
,
"bf16"
]
...
...
@@ -73,6 +76,12 @@ def generate_new_kernels():
# for fp8
if
scalar_type
==
"vllm::kFE4M3fn"
and
group_blocks
not
in
[
-
1
,
8
]:
continue
# nvfp4 only supports group_size == 16
if
scalar_type
==
"vllm::kFE2M1f"
and
group_blocks
!=
1
:
continue
# other quantization methods don't support group_size = 16
if
scalar_type
!=
"vllm::kFE2M1f"
and
group_blocks
==
1
:
continue
k_blocks
=
thread_configs
[
0
]
//
16
n_blocks
=
thread_configs
[
1
]
//
16
...
...
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
d74e5f37
...
...
@@ -258,6 +258,7 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order case (group_blocks == 0)
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
...
...
@@ -314,6 +315,23 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128)
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF(W_TYPE) \
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
FP4_GET_IF_M234(W_TYPE, 4, 8, 128)
// We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
...
...
@@ -366,6 +384,8 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
COMMON_GET_IF
(
vllm
::
kU4B8
)
COMMON_GET_IF
(
vllm
::
kU8B128
)
FP4_GET_IF
(
vllm
::
kFE2M1f
)
BIGGROUP_GET_IF
(
vllm
::
kFE4M3fn
)
ACT_GET_IF
(
vllm
::
kU4B8
)
...
...
@@ -434,8 +454,8 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
template
<
typename
scalar_t
>
void
marlin_mm
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
C_tmp
,
void
*
s
,
void
*
zp
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
lda
,
void
*
workspace
,
void
*
s2
,
void
*
zp
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
lda
,
void
*
workspace
,
vllm
::
ScalarType
const
&
q_type
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_zp
,
int
num_groups
,
int
group_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k_init
,
...
...
@@ -446,11 +466,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
q_type
==
vllm
::
kU4
||
q_type
==
vllm
::
kU8
,
"q_type must be u4 or u8 when has_zp = True. Got = "
,
q_type
.
str
());
}
else
{
TORCH_CHECK
(
q_type
==
vllm
::
kU4B8
||
q_type
==
vllm
::
kU8B128
||
q_type
==
vllm
::
kFE4M3fn
,
"q_type must be uint4b8, uint8b128 or float8_e4m3fn when "
"has_zp = False. Got = "
,
q_type
.
str
());
TORCH_CHECK
(
q_type
==
vllm
::
kU4B8
||
q_type
==
vllm
::
kU8B128
||
q_type
==
vllm
::
kFE4M3fn
||
q_type
==
vllm
::
kFE2M1f
,
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
"has_zp = False. Got = "
,
q_type
.
str
());
}
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
...
...
@@ -483,6 +504,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
int4
*
C_ptr
=
(
int4
*
)
C
;
int4
*
C_tmp_ptr
=
(
int4
*
)
C_tmp
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
const
uint16_t
*
s2_ptr
=
(
const
uint16_t
*
)
s2
;
const
int4
*
zp_ptr
=
(
const
int4
*
)
zp
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
;
const
int
*
perm_ptr
=
(
const
int
*
)
perm
;
...
...
@@ -601,7 +623,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel
<<<
blocks
,
num_threads
,
max_shared_mem_new
,
stream
>>>
(
A_ptr
,
B_ptr
,
C_ptr
,
C_tmp_ptr
,
s_ptr
,
zp_ptr
,
g_idx_ptr
,
num_groups
,
A_ptr
,
B_ptr
,
C_ptr
,
C_tmp_ptr
,
s_ptr
,
s2_ptr
,
zp_ptr
,
g_idx_ptr
,
num_groups
,
prob_m_split
,
prob_n
,
prob_k
,
lda
,
locks
,
part_use_atomic_add
,
use_fp32_reduce
,
max_shared_mem_new
);
// clang-format on
...
...
@@ -617,6 +639,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
std
::
optional
<
torch
::
Tensor
>
c_or_none
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
global_scale_or_none
,
std
::
optional
<
torch
::
Tensor
>
const
&
b_zeros_or_none
,
std
::
optional
<
torch
::
Tensor
>
const
&
g_idx_or_none
,
std
::
optional
<
torch
::
Tensor
>
const
&
perm_or_none
,
torch
::
Tensor
&
workspace
,
...
...
@@ -759,6 +782,17 @@ torch::Tensor gptq_marlin_gemm(
}
}
torch
::
Tensor
global_scale
;
if
(
global_scale_or_none
.
has_value
())
{
global_scale
=
global_scale_or_none
.
value
();
TORCH_CHECK
(
b_q_type
==
vllm
::
kFE2M1f
,
"global_scale can only be used for float4_e2m1f."
);
}
else
{
global_scale
=
torch
::
empty
({
0
},
options
);
TORCH_CHECK
(
!
(
b_q_type
==
vllm
::
kFE2M1f
),
"the global_scale parameter must be passed for float4_e2m1f."
);
}
torch
::
Tensor
b_zeros
;
if
(
b_zeros_or_none
.
has_value
())
{
b_zeros
=
b_zeros_or_none
.
value
();
...
...
@@ -774,8 +808,9 @@ torch::Tensor gptq_marlin_gemm(
"b_q_type must be u4 or u8 when has_zp = True. Got = "
,
b_q_type
.
str
());
}
else
{
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4B8
||
b_q_type
==
vllm
::
kU8B128
||
b_q_type
==
vllm
::
kFE4M3fn
,
"b_q_type must be uint4b8, uint8b128 or float8_e4m3fn when "
b_q_type
==
vllm
::
kFE4M3fn
||
b_q_type
==
vllm
::
kFE2M1f
,
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or "
"float4_e2m1f when "
"has_zp = False. Got = "
,
b_q_type
.
str
());
}
...
...
@@ -820,22 +855,36 @@ torch::Tensor gptq_marlin_gemm(
int
dev
=
a
.
get_device
();
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
void
*
scales_ptr
;
if
(
b_q_type
==
vllm
::
kFE2M1f
)
{
scales_ptr
=
b_scales
.
data_ptr
<
at
::
Float8_e4m3fn
>
();
}
else
{
scales_ptr
=
b_scales
.
data_ptr
<
at
::
Half
>
();
}
marlin
::
marlin_mm
<
half
>
(
a
.
data_ptr
<
at
::
Half
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
Half
>
(),
c_tmp
.
data_ptr
<
float
>
(),
b
_scale
s
.
data_ptr
<
at
::
Half
>
(),
c_tmp
.
data_ptr
<
float
>
(),
scales_ptr
,
global
_scale
.
data_ptr
<
at
::
Half
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
a
.
stride
(
0
),
workspace
.
data_ptr
(),
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
);
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
void
*
scales_ptr
;
if
(
b_q_type
==
vllm
::
kFE2M1f
)
{
scales_ptr
=
b_scales
.
data_ptr
<
at
::
Float8_e4m3fn
>
();
}
else
{
scales_ptr
=
b_scales
.
data_ptr
<
at
::
BFloat16
>
();
}
marlin
::
marlin_mm
<
nv_bfloat16
>
(
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
c_tmp
.
data_ptr
<
float
>
(),
b
_scale
s
.
data_ptr
<
at
::
BFloat16
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
a
.
stride
(
0
),
workspace
.
data_ptr
(),
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
c
.
data_ptr
<
at
::
BFloat16
>
(),
c_tmp
.
data_ptr
<
float
>
(),
scales_ptr
,
global
_scale
.
data_ptr
<
at
::
BFloat16
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
a
.
stride
(
0
),
workspace
.
data_ptr
(),
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
);
}
else
{
...
...
csrc/quantization/gptq_marlin/kernel.h
View file @
d74e5f37
...
...
@@ -7,13 +7,14 @@
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, \
int prob_k, int lda, int *locks, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem
namespace
MARLIN_NAMESPACE_NAME
{
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
...
...
csrc/quantization/gptq_marlin/marlin_template.h
View file @
d74e5f37
...
...
@@ -292,9 +292,11 @@ __global__ void Marlin(
int4
*
__restrict__
C_tmp
,
// fp32 tmp output buffer (for reduce)
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
const
uint16_t
*
__restrict__
scale2_ptr
,
// fp16 global scale (for nvfp4
// only)
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
...
...
@@ -325,6 +327,21 @@ __global__ void Marlin(
static
constexpr
auto
w_type
=
vllm
::
ScalarType
::
from_id
(
w_type_id
);
constexpr
bool
has_zp
=
w_type
==
vllm
::
kU4
||
w_type
==
vllm
::
kU8
;
constexpr
bool
is_int_type
=
w_type
==
vllm
::
kU4
||
w_type
==
vllm
::
kU8
||
w_type
==
vllm
::
kU4B8
||
w_type
==
vllm
::
kU8B128
;
// see comments of dequant.h for more details
constexpr
bool
dequant_skip_flop
=
!
is_int_type
||
has_zp
&&
!
is_zp_float
&&
!
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
||
has_zp
&&
!
is_zp_float
&&
!
(
w_type
==
vllm
::
kU8
);
scalar_t2
global_scale
;
if
constexpr
(
w_type
==
vllm
::
kFE2M1f
)
{
uint16_t
val
=
scale2_ptr
[
0
];
global_scale
=
Dtype
::
num2num2
(
*
reinterpret_cast
<
scalar_t
*>
(
&
val
));
}
constexpr
bool
has_act_order
=
group_blocks
==
0
;
constexpr
int
m_block_size
=
m_block_size_8
?
8
:
(
16
*
thread_m_blocks
);
...
...
@@ -481,7 +498,7 @@ __global__ void Marlin(
constexpr
int
s_sh_stride
=
16
*
thread_n_blocks
/
8
;
constexpr
int
s_tb_groups
=
!
has_act_order
&&
group_blocks
!=
-
1
&&
group_blocks
<
thread_k_blocks
?
thread_k_blocks
/
group_blocks
?
thread_k_blocks
/
group_blocks
/
(
w_type
==
vllm
::
kFE2M1f
?
2
:
1
)
:
1
;
constexpr
int
s_sh_stage
=
s_tb_groups
*
s_sh_stride
;
int
s_gl_rd_delta
=
s_gl_stride
;
...
...
@@ -540,7 +557,8 @@ __global__ void Marlin(
if
constexpr
(
group_blocks
==
-
1
)
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
else
{
s_gl_rd
=
s_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
+
s_gl_rd
=
s_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
/
(
w_type
==
vllm
::
kFE2M1f
?
2
:
1
)
+
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
...
...
@@ -564,10 +582,20 @@ __global__ void Marlin(
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int
s_sh_rd
;
if
constexpr
(
group_blocks
!=
-
1
)
if
constexpr
(
group_blocks
!=
-
1
&&
w_type
==
vllm
::
kFE2M1f
)
{
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
4
;
else
if
constexpr
(
group_blocks
==
-
1
&&
(
m_block_size_8
||
has_zp
))
s_sh_rd
=
s_sh_rd
*
2
+
warp_row
%
2
;
}
else
if
constexpr
(
group_blocks
!=
-
1
)
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
4
;
else
if
constexpr
(
group_blocks
==
-
1
&&
(
m_block_size_8
||
(
has_zp
&&
!
dequant_skip_flop
)))
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
8
;
else
...
...
@@ -681,7 +709,7 @@ __global__ void Marlin(
sh_first_group_id
=
first_group_id
;
sh_num_groups
=
last_group_id
-
first_group_id
+
1
;
if
(
sh_num_groups
<
act_s_max_num_groups
)
{
if
(
sh_num_groups
>
act_s_max_num_groups
)
{
sh_num_groups
=
act_s_max_num_groups
;
}
...
...
@@ -887,12 +915,19 @@ __global__ void Marlin(
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
int
cur_group_id
=
k_blocks
/
group_blocks
;
int
cur_group_id
=
k_blocks
/
(
group_blocks
*
(
w_type
==
vllm
::
kFE2M1f
?
2
:
1
));
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
+
cur_group_id
*
s_sh_stride
];
if
constexpr
(
w_type_id
!=
vllm
::
kFE2M1f
.
id
())
{
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
+
cur_group_id
*
s_sh_stride
];
}
else
{
reinterpret_cast
<
int2
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
reinterpret_cast
<
int2
*>
(
sh_s_stage
)[
s_sh_rd
+
cur_group_id
*
(
2
*
s_sh_stride
)];
}
}
}
...
...
@@ -1065,22 +1100,7 @@ __global__ void Marlin(
};
auto
dequant_data
=
[
&
](
int
q
,
scalar_t2
*
frag_b_ptr
)
{
if
constexpr
(
has_zp
&&
is_zp_float
||
!
has_zp
)
{
dequant
<
scalar_t2
,
w_type_id
>
(
q
,
frag_b_ptr
);
}
else
{
static_assert
(
has_zp
&&
!
is_zp_float
);
static_assert
(
w_type_id
==
vllm
::
kU4
.
id
()
||
w_type_id
==
vllm
::
kU8
.
id
());
// If (has_zp && !is_zp_float),
// we use not-zp version `dequant` function
// to improve numerical accuracy.
// Since both weight and zero point are dequanted using this logic,
// the final dequanted weight would be correct.
if
constexpr
(
w_type_id
==
vllm
::
kU4
.
id
())
{
dequant
<
scalar_t2
,
vllm
::
kU4B8
.
id
()
>
(
q
,
frag_b_ptr
);
}
else
if
constexpr
(
w_type_id
==
vllm
::
kU8
.
id
())
{
dequant
<
scalar_t2
,
vllm
::
kU8B128
.
id
()
>
(
q
,
frag_b_ptr
);
}
}
dequant
<
scalar_t2
,
w_type_id
,
dequant_skip_flop
>
(
q
,
frag_b_ptr
);
};
// Execute the actual tensor core matmul of a sub-tile.
...
...
@@ -1110,13 +1130,23 @@ __global__ void Marlin(
dequant_data
(
zp_quant_1
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_zp
)
+
2
);
}
}
if
constexpr
(
has_zp
&&
is_zp_float
)
{
if
constexpr
(
!
dequant_skip_flop
&&
has_zp
&&
is_zp_float
)
{
if
(
is_new_zp
)
{
reinterpret_cast
<
int4
*>
(
&
frag_zp
)[
0
]
=
reinterpret_cast
<
int4
*>
(
&
frag_zpf
[
k2
])[
0
];
}
}
if
constexpr
(
w_type
==
vllm
::
kFE2M1f
)
{
int
s_quant_0
=
reinterpret_cast
<
int
*>
(
frag_s
[
k2
])[
0
];
int
s_quant_1
=
reinterpret_cast
<
int
*>
(
frag_s
[
k2
])[
1
];
dequant_fp8_scales
<
scalar_t2
>
(
s_quant_0
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_s
[
k2
]));
dequant_fp8_scales
<
scalar_t2
>
(
s_quant_1
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_s
[
k2
])
+
2
);
}
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
...
...
@@ -1125,7 +1155,10 @@ __global__ void Marlin(
FragB
frag_b1
;
int
b_quant_0
,
b_quant_1
;
if
constexpr
(
w_type
.
size_bits
()
==
4
)
{
if
constexpr
(
w_type_id
==
vllm
::
kFE2M1f
.
id
())
{
b_quant_1
=
frag_b_quant
[
k2
][
0
][
j
];
b_quant_0
=
b_quant_1
<<
8
;
}
else
if
constexpr
(
w_type
.
size_bits
()
==
4
)
{
b_quant_0
=
frag_b_quant
[
k2
][
0
][
j
];
b_quant_1
=
b_quant_0
>>
8
;
}
else
{
...
...
@@ -1138,6 +1171,11 @@ __global__ void Marlin(
dequant_data
(
b_quant_0
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_b0
));
dequant_data
(
b_quant_1
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_b1
));
if
constexpr
(
dequant_skip_flop
&&
has_zp
&&
!
is_zp_float
)
{
sub_zp
<
scalar_t
>
(
frag_b0
,
frag_zp
[
j
],
0
);
sub_zp
<
scalar_t
>
(
frag_b1
,
frag_zp
[
j
],
1
);
}
// Apply scale to frag_b0
if
constexpr
(
has_act_order
)
{
static_assert
(
group_blocks
!=
-
1
);
...
...
@@ -1145,7 +1183,8 @@ __global__ void Marlin(
act_frag_s
[
k2
][
2
][
j
],
act_frag_s
[
k2
][
3
][
j
],
0
);
scale4
<
scalar_t
>
(
frag_b1
,
act_frag_s
[
k2
][
0
][
j
],
act_frag_s
[
k2
][
1
][
j
],
act_frag_s
[
k2
][
2
][
j
],
act_frag_s
[
k2
][
3
][
j
],
1
);
}
else
if
constexpr
(
has_zp
&&
!
is_zp_float
&&
group_blocks
==
-
1
)
{
}
else
if
constexpr
(
!
dequant_skip_flop
&&
has_zp
&&
!
is_zp_float
&&
group_blocks
==
-
1
)
{
int
idx
=
(
threadIdx
.
x
/
4
)
%
2
;
scalar_t2
s2
=
Dtype
::
nums2num2
(
reinterpret_cast
<
scalar_t
*>
(
&
frag_s
[
j
/
2
][
j
%
2
*
2
+
0
])[
idx
],
...
...
@@ -1153,7 +1192,7 @@ __global__ void Marlin(
if
(
is_new_zp
)
frag_zp
[
j
]
=
__hmul2
(
frag_zp
[
j
],
s2
);
scale_and_sub
<
scalar_t
>
(
frag_b0
,
s2
.
x
,
frag_zp
[
j
].
x
);
scale_and_sub
<
scalar_t
>
(
frag_b1
,
s2
.
y
,
frag_zp
[
j
].
y
);
}
else
if
constexpr
(
has_zp
&&
group_blocks
!=
-
1
)
{
}
else
if
constexpr
(
!
dequant_skip_flop
&&
has_zp
&&
group_blocks
!=
-
1
)
{
if
(
is_new_zp
)
frag_zp
[
j
]
=
__hmul2
(
frag_zp
[
j
],
*
reinterpret_cast
<
scalar_t2
*>
(
&
frag_s
[
k2
][
j
]));
...
...
@@ -1408,10 +1447,15 @@ __global__ void Marlin(
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
w_type
.
size_bits
()
==
4
&&
!
has_zp
)
{
w_type
.
size_bits
()
==
4
&&
(
has_zp
&&
dequant_skip_flop
||
!
has_zp
))
{
res
=
__hmul2
(
res
,
s
[
0
]);
}
if
constexpr
(
w_type
==
vllm
::
kFE2M1f
)
{
res
=
__hmul2
(
res
,
global_scale
);
}
if
constexpr
(
m_block_size_8
)
{
((
scalar_t
*
)
sh_red
)[
idx
]
=
res
.
x
;
((
scalar_t
*
)
sh_red
)[
idx
+
8
*
c_sh_stride
]
=
res
.
y
;
...
...
@@ -1488,7 +1532,9 @@ __global__ void Marlin(
if
constexpr
(
has_zp
&&
!
is_zp_float
&&
group_blocks
==
-
1
)
{
if
(
i
==
0
)
{
fetch_col_zp_to_shared
();
fetch_col_scale_to_shared
();
if
constexpr
(
!
dequant_skip_flop
)
{
fetch_col_scale_to_shared
();
}
}
}
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
);
...
...
@@ -1563,7 +1609,8 @@ __global__ void Marlin(
bool
last
=
slice_idx
==
slice_count
-
1
;
// For per-column scales, we only fetch them here in the final step before
// write-out
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
!
has_zp
)
{
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
(
has_zp
&&
dequant_skip_flop
||
!
has_zp
))
{
if
(
w_type
.
size_bits
()
==
8
||
(
last
||
use_atomic_add
))
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
...
...
@@ -1573,7 +1620,8 @@ __global__ void Marlin(
}
thread_block_reduce
();
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
!
has_zp
)
{
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
(
has_zp
&&
dequant_skip_flop
||
!
has_zp
))
{
if
(
w_type
.
size_bits
()
==
8
||
(
last
||
use_atomic_add
))
{
cp_async_wait
<
0
>
();
__syncthreads
();
...
...
@@ -1597,7 +1645,8 @@ __global__ void Marlin(
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
w_type
.
size_bits
()
==
8
&&
!
has_zp
)
{
w_type
.
size_bits
()
==
8
&&
(
has_zp
&&
dequant_skip_flop
||
!
has_zp
))
{
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
...
...
csrc/torch_bindings.cpp
View file @
d74e5f37
...
...
@@ -292,8 +292,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops
.
def
(
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
"Tensor b_scales, Tensor? b_zeros_or_none, Tensor?
g_idx_or_none,
"
"Tensor? perm_or_none, Tensor workspace, int b_q_type, "
"Tensor b_scales, Tensor?
global_scale, Tensor?
b_zeros_or_none, Tensor? "
"
g_idx_or_none,
Tensor? perm_or_none, Tensor workspace, int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor"
,
{
stride_tag
});
...
...
tests/kernels/moe/test_moe.py
View file @
d74e5f37
...
...
@@ -16,6 +16,8 @@ from vllm.model_executor.layers.fused_moe import fused_moe
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.moe_torch_iterative
import
(
fused_moe
as
iterative_moe
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
rand_marlin_weight_fp4_like
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
marlin_quant_fp8_torch
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
...
...
@@ -286,21 +288,64 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
atol
=
mixtral_moe_tol
[
dtype
])
def marlin_moe_generate_valid_test_cases():
    """Return the parameter tuples used to parametrize the fused Marlin MoE test.

    The full cartesian product of the candidate values below is generated
    and then filtered so that only combinations accepted by ``is_valid``
    remain.

    Returns:
        A list of ``(m, n, k, e, topk, ep_size, dtype, group_size,
        act_order, quant_type, is_k_full)`` tuples.
    """
    import itertools

    m_list = [1, 123, 666]
    n_list = [128, 1024]
    k_list = [256, 2048]
    e_list = [4, 12]
    topk_list = [2, 3]
    ep_size_list = [1, 4]
    dtype_list = [torch.half, torch.bfloat16]
    group_size_list = [-1, 16, 32, 128]
    act_order_list = [True, False]
    quant_type_list = [
        scalar_types.float4_e2m1f,
        scalar_types.float8_e4m3fn,
        scalar_types.uint4,
        scalar_types.uint4b8,
        scalar_types.uint8b128,
    ]
    is_k_full_list = [True, False]

    all_combinations = itertools.product(m_list, n_list, k_list, e_list,
                                         topk_list, ep_size_list, dtype_list,
                                         group_size_list, act_order_list,
                                         quant_type_list, is_k_full_list)

    def is_valid(m, n, k, e, topk, ep_size, dtype, group_size, act_order,
                 quant_type, is_k_full):
        """Return True when the combination should be kept as a test case."""
        # fp8 weights are only exercised with group_size -1 or 128.
        if quant_type == scalar_types.float8_e4m3fn and \
                group_size not in [-1, 128]:
            return False
        # fp4 weights are paired with group_size 16, and group_size 16 is
        # paired with fp4 weights only.
        if quant_type == scalar_types.float4_e2m1f and group_size != 16:
            return False
        if quant_type != scalar_types.float4_e2m1f and group_size == 16:
            return False

        # Filter act_order
        if act_order:
            # Drop act_order cases whose group size is -1 or equals k or n.
            if group_size in (-1, k, n):
                return False
            # act_order is only exercised with uint4b8.
            if quant_type not in [scalar_types.uint4b8]:
                return False
        elif not is_k_full:
            # Without act_order only the is_k_full=True variant is kept.
            return False

        return True

    return [case for case in all_combinations if is_valid(*case)]
@
pytest
.
mark
.
flaky
(
reruns
=
2
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
123
,
666
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
256
,
2048
])
@
pytest
.
mark
.
parametrize
(
"e"
,
[
4
,
12
])
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
2
,
3
])
@
pytest
.
mark
.
parametrize
(
"ep_size"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
-
1
,
32
,
128
])
@
pytest
.
mark
.
parametrize
(
"act_order"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
scalar_types
.
uint4
,
scalar_types
.
uint8b128
,
scalar_types
.
uint4b8
,
scalar_types
.
float8_e4m3fn
])
@
pytest
.
mark
.
parametrize
(
"is_k_full"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
((
"m, n, k, e, topk, ep_size, dtype, group_size,"
"act_order, quant_type, is_k_full"
),
marlin_moe_generate_valid_test_cases
())
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Skip for rocm"
)
def
test_fused_marlin_moe
(
m
:
int
,
...
...
@@ -338,6 +383,11 @@ def test_fused_marlin_moe(
if
not
is_k_full
:
return
if
quant_type
==
scalar_types
.
float4_e2m1f
and
group_size
!=
16
:
return
if
quant_type
!=
scalar_types
.
float4_e2m1f
and
group_size
==
16
:
return
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
20
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
20
...
...
@@ -355,12 +405,27 @@ def test_fused_marlin_moe(
w_ref1_l
=
[]
qweight1_l
=
[]
scales1_l
=
[]
global_scale1_l
=
[]
zeros1_l
=
[]
g_idx1_l
=
[]
sort_indices1_l
=
[]
for
i
in
range
(
w1
.
shape
[
0
]):
if
has_zp
:
if
quant_type
==
scalar_types
.
float4_e2m1f
:
w_ref1
,
qweight1
,
scales1
,
global_scale1
=
\
rand_marlin_weight_fp4_like
(
w1
[
i
],
group_size
)
w_ref1_l
.
append
(
w_ref1
.
T
)
qweight1_l
.
append
(
qweight1
)
scales1_l
.
append
(
scales1
)
global_scale1_l
.
append
(
global_scale1
)
elif
quant_type
==
scalar_types
.
float8_e4m3fn
:
w_ref1
,
qweight1
,
scales1
=
marlin_quant_fp8_torch
(
w1
[
i
],
group_size
)
w_ref1_l
.
append
(
w_ref1
.
T
)
qweight1_l
.
append
(
qweight1
)
scales1_l
.
append
(
scales1
)
elif
has_zp
:
w_ref1
,
qweight1
,
scales1
,
zeros1
=
awq_marlin_quantize
(
w1
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
)
...
...
@@ -368,7 +433,7 @@ def test_fused_marlin_moe(
qweight1_l
.
append
(
qweight1
)
scales1_l
.
append
(
scales1
)
zeros1_l
.
append
(
zeros1
)
el
if
quant_type
!=
scalar_types
.
float8_e4m3fn
:
el
se
:
test_perm
=
torch
.
randperm
(
k
)
w_ref1
,
qweight1
,
scales1
,
g_idx1
,
sort_indices1
,
_
=
\
marlin_quantize
(
w1
[
i
].
transpose
(
1
,
0
),
quant_type
,
...
...
@@ -379,16 +444,11 @@ def test_fused_marlin_moe(
scales1_l
.
append
(
scales1
)
g_idx1_l
.
append
(
g_idx1
)
sort_indices1_l
.
append
(
sort_indices1
)
else
:
w_ref1
,
qweight1
,
scales1
=
marlin_quant_fp8_torch
(
w1
[
i
],
group_size
)
w_ref1_l
.
append
(
w_ref1
.
T
)
qweight1_l
.
append
(
qweight1
)
scales1_l
.
append
(
scales1
)
w_ref1
=
stack_and_dev
(
w_ref1_l
)
qweight1
=
stack_and_dev
(
qweight1_l
).
contiguous
()
scales1
=
stack_and_dev
(
scales1_l
)
global_scale1
=
stack_and_dev
(
global_scale1_l
)
if
global_scale1_l
else
None
g_idx1
=
stack_and_dev
(
g_idx1_l
)
if
g_idx1_l
else
None
zeros1
=
stack_and_dev
(
zeros1_l
)
if
zeros1_l
else
None
sort_indices1
=
stack_and_dev
(
sort_indices1_l
)
if
sort_indices1_l
else
None
...
...
@@ -396,12 +456,27 @@ def test_fused_marlin_moe(
w_ref2_l
=
[]
qweight2_l
=
[]
scales2_l
=
[]
global_scale2_l
=
[]
zeros2_l
=
[]
g_idx2_l
=
[]
sort_indices2_l
=
[]
for
i
in
range
(
w2
.
shape
[
0
]):
if
has_zp
:
if
quant_type
==
scalar_types
.
float4_e2m1f
:
w_ref2
,
qweight2
,
scales2
,
global_scale2
=
\
rand_marlin_weight_fp4_like
(
w2
[
i
],
group_size
)
w_ref2_l
.
append
(
w_ref2
.
T
)
qweight2_l
.
append
(
qweight2
)
scales2_l
.
append
(
scales2
)
global_scale2_l
.
append
(
global_scale2
)
elif
quant_type
==
scalar_types
.
float8_e4m3fn
:
w_ref2
,
qweight2
,
scales2
=
marlin_quant_fp8_torch
(
w2
[
i
],
group_size
)
w_ref2_l
.
append
(
w_ref2
.
T
)
qweight2_l
.
append
(
qweight2
)
scales2_l
.
append
(
scales2
)
elif
has_zp
:
w_ref2
,
qweight2
,
scales2
,
zeros2
=
awq_marlin_quantize
(
w2
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
)
...
...
@@ -409,7 +484,7 @@ def test_fused_marlin_moe(
qweight2_l
.
append
(
qweight2
)
scales2_l
.
append
(
scales2
)
zeros2_l
.
append
(
zeros2
)
el
if
quant_type
!=
scalar_types
.
float8_e4m3fn
:
el
se
:
test_perm
=
torch
.
randperm
(
n
)
w_ref2
,
qweight2
,
scales2
,
g_idx2
,
sort_indices2
,
_
=
\
marlin_quantize
(
w2
[
i
].
transpose
(
1
,
0
),
quant_type
,
...
...
@@ -420,24 +495,18 @@ def test_fused_marlin_moe(
scales2_l
.
append
(
scales2
)
g_idx2_l
.
append
(
g_idx2
)
sort_indices2_l
.
append
(
sort_indices2
)
else
:
w_ref2
,
qweight2
,
scales2
=
marlin_quant_fp8_torch
(
w2
[
i
],
group_size
)
w_ref2_l
.
append
(
w_ref2
.
T
)
qweight2_l
.
append
(
qweight2
)
scales2_l
.
append
(
scales2
)
w_ref2
=
stack_and_dev
(
w_ref2_l
)
qweight2
=
stack_and_dev
(
qweight2_l
).
contiguous
()
scales2
=
stack_and_dev
(
scales2_l
)
global_scale2
=
stack_and_dev
(
global_scale2_l
)
if
global_scale2_l
else
None
g_idx2
=
stack_and_dev
(
g_idx2_l
)
if
g_idx2_l
else
None
zeros2
=
stack_and_dev
(
zeros2_l
)
if
zeros2_l
else
None
sort_indices2
=
stack_and_dev
(
sort_indices2_l
)
if
sort_indices2_l
else
None
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_topk
(
a
,
score
,
topk
,
False
)
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
torch_output
=
torch_moe
(
a
,
w_ref1
,
w_ref2
,
score
,
topk
,
e_map
)
...
...
@@ -452,6 +521,8 @@ def test_fused_marlin_moe(
topk_ids
,
global_num_experts
=
e
,
expert_map
=
e_map
,
global_scale1
=
global_scale1
,
global_scale2
=
global_scale2
,
g_idx1
=
g_idx1
,
g_idx2
=
g_idx2
,
sort_indices1
=
sort_indices1
,
...
...
tests/kernels/quantization/test_marlin_gemm.py
View file @
d74e5f37
...
...
@@ -20,6 +20,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES
,
marlin_make_empty_g_idx
,
marlin_make_workspace_new
,
marlin_permute_scales
,
query_marlin_supported_quant_types
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
FP4_MARLIN_SUPPORTED_GROUP_SIZES
,
rand_marlin_weight_fp4_like
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
marlin_quant_fp8_torch
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
...
...
@@ -190,9 +192,10 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
query_marlin_supported_quant_types
(
False
))
@
pytest
.
mark
.
parametrize
(
"group_size"
,
MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
query_marlin_supported_quant_types
())
@
pytest
.
mark
.
parametrize
(
"group_size"
,
set
(
MARLIN_SUPPORTED_GROUP_SIZES
+
FP4_MARLIN_SUPPORTED_GROUP_SIZES
))
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"is_k_full"
,
K_FULL_OPTS
)
...
...
@@ -210,6 +213,7 @@ def test_gptq_marlin_gemm(
use_fp32_reduce
,
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
has_zp
=
quant_type
in
[
scalar_types
.
uint4
,
scalar_types
.
uint8
]
size_m
=
m_factor
size_k
=
k_chunk
*
k_factor
...
...
@@ -220,6 +224,8 @@ def test_gptq_marlin_gemm(
return
if
group_size
==
size_k
:
return
if
has_zp
:
return
if
size_k
%
group_size
!=
0
:
return
...
...
@@ -227,7 +233,15 @@ def test_gptq_marlin_gemm(
a_input
=
rand_data
((
size_m
,
size_k
))
b_weight
=
rand_data
((
size_k
,
size_n
))
if
quant_type
==
scalar_types
.
float8_e4m3fn
:
if
quant_type
==
scalar_types
.
float4_e2m1f
:
if
group_size
!=
16
or
act_order
:
return
w_ref
,
marlin_q_w
,
marlin_s
,
marlin_s2
=
rand_marlin_weight_fp4_like
(
b_weight
.
T
,
group_size
)
g_idx
=
None
sort_indices
=
None
marlin_zp
=
None
elif
quant_type
==
scalar_types
.
float8_e4m3fn
:
if
group_size
not
in
[
-
1
,
128
]:
return
if
act_order
:
...
...
@@ -236,26 +250,39 @@ def test_gptq_marlin_gemm(
b_weight
.
T
,
group_size
)
g_idx
=
None
sort_indices
=
None
marlin_zp
=
None
marlin_s2
=
None
elif
has_zp
:
if
group_size
==
16
:
return
w_ref
,
marlin_q_w
,
marlin_s
,
marlin_zp
=
awq_marlin_quantize
(
b_weight
,
quant_type
,
group_size
)
g_idx
=
None
sort_indices
=
None
marlin_s2
=
None
else
:
if
group_size
==
16
:
return
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
_
=
marlin_quantize
(
b_weight
,
quant_type
,
group_size
,
act_order
)
marlin_
zp
=
marlin_make_empty_g_idx
(
marlin_s
.
device
)
marlin_zp
=
None
marlin_
s2
=
None
workspace
=
marlin_make_workspace_new
(
w_ref
.
device
)
opcheck
(
torch
.
ops
.
_C
.
gptq_marlin_gemm
,
(
a_input
,
None
,
marlin_q_w
,
marlin_s
,
marlin_zp
,
g_
id
x
,
sort_indices
,
workspace
,
quant_type
.
id
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
]
,
a_input
.
shape
[
1
],
is_k_full
,
use_atomic_add
,
use_fp32_reduce
,
False
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
)
opcheck
(
torch
.
ops
.
_C
.
gptq_marlin_gemm
,
(
a_input
,
None
,
marlin_q_w
,
marlin_s
,
marlin_s2
,
marlin_zp
,
g_idx
,
sort_indices
,
workspace
,
quant_type
.
id
,
a_input
.
shape
[
0
]
,
b_weight
.
shape
[
1
]
,
a_input
.
shape
[
1
],
is_k_full
,
use_atomic_add
,
use_fp32_reduce
,
False
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
)
output
=
ops
.
gptq_marlin_gemm
(
a_input
,
None
,
marlin_q_w
,
marlin_s
,
marlin_s2
,
marlin_zp
,
g_idx
,
sort_indices
,
...
...
@@ -339,67 +366,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
assert
max_diff
<
0.04
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
query_marlin_supported_quant_types
(
True
))
@
pytest
.
mark
.
parametrize
(
"group_size"
,
MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"use_fp32_reduce"
,
USE_FP32_REDUCE_OPTS
)
def
test_awq_marlin_gemm
(
k_chunk
,
n_chunk
,
quant_type
,
group_size
,
mnk_factors
,
use_fp32_reduce
,
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
size_m
=
m_factor
size_k
=
k_chunk
*
k_factor
size_n
=
n_chunk
*
n_factor
a_input
=
rand_data
((
size_m
,
size_k
))
b_weight
=
rand_data
((
size_k
,
size_n
))
w_ref
,
marlin_q_w
,
marlin_s
,
marlin_zp
=
awq_marlin_quantize
(
b_weight
,
quant_type
,
group_size
)
g_idx
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
marlin_q_w
.
device
)
sort_indices
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
marlin_q_w
.
device
)
is_k_full
=
True
workspace
=
marlin_make_workspace_new
(
a_input
.
device
)
output
=
ops
.
gptq_marlin_gemm
(
a_input
,
None
,
marlin_q_w
,
marlin_s
,
marlin_zp
,
g_idx
,
sort_indices
,
workspace
,
quant_type
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
=
is_k_full
,
use_fp32_reduce
=
use_fp32_reduce
,
is_zp_float
=
False
,
)
output_ref
=
torch
.
matmul
(
a_input
,
w_ref
)
torch
.
cuda
.
synchronize
()
max_diff
=
compute_max_diff
(
output
,
output_ref
)
assert
max_diff
<
0.04
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
...
...
@@ -452,6 +418,7 @@ def test_hqq_marlin_gemm(
None
,
marlin_w_q
,
marlin_s
,
None
,
marlin_zp
,
g_idx
,
g_idx_sort_indices
,
...
...
@@ -564,6 +531,7 @@ def test_marlin_gemm_subset_input():
None
,
marlin_q_w
,
marlin_s
,
None
,
marlin_zp
,
g_idx
,
sort_indices
,
...
...
vllm/_custom_ops.py
View file @
d74e5f37
...
...
@@ -333,6 +333,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
c
:
Optional
[
torch
.
Tensor
],
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
global_scale
:
Optional
[
torch
.
Tensor
],
b_zeros
:
Optional
[
torch
.
Tensor
],
g_idx
:
Optional
[
torch
.
Tensor
],
perm
:
Optional
[
torch
.
Tensor
],
...
...
@@ -866,6 +867,7 @@ def gptq_marlin_gemm(a: torch.Tensor,
c
:
Optional
[
torch
.
Tensor
],
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
global_scale
:
Optional
[
torch
.
Tensor
],
b_zeros
:
Optional
[
torch
.
Tensor
],
g_idx
:
Optional
[
torch
.
Tensor
],
perm
:
Optional
[
torch
.
Tensor
],
...
...
@@ -878,9 +880,10 @@ def gptq_marlin_gemm(a: torch.Tensor,
use_atomic_add
:
bool
=
False
,
use_fp32_reduce
:
bool
=
False
,
is_zp_float
:
bool
=
False
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
gptq_marlin_gemm
(
a
,
c
,
b_q_weight
,
b_scales
,
b_zeros
,
g_idx
,
perm
,
workspace
,
b_q_type
.
id
,
size_m
,
size_n
,
size_k
,
is_k_full
,
return
torch
.
ops
.
_C
.
gptq_marlin_gemm
(
a
,
c
,
b_q_weight
,
b_scales
,
global_scale
,
b_zeros
,
g_idx
,
perm
,
workspace
,
b_q_type
.
id
,
size_m
,
size_n
,
size_k
,
is_k_full
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
)
...
...
@@ -1381,6 +1384,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
def
moe_wna16_marlin_gemm
(
input
:
torch
.
Tensor
,
output
:
Optional
[
torch
.
Tensor
],
b_qweight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
global_scale
:
Optional
[
torch
.
Tensor
],
b_qzeros
:
Optional
[
torch
.
Tensor
],
g_idx
:
Optional
[
torch
.
Tensor
],
perm
:
Optional
[
torch
.
Tensor
],
...
...
@@ -1395,11 +1399,11 @@ def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
use_fp32_reduce
:
bool
,
is_zp_float
:
bool
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_moe_C
.
moe_wna16_marlin_gemm
(
input
,
output
,
b_qweight
,
b_scales
,
b_qzeros
,
g_idx
,
perm
,
workspace
,
sorted_token_ids
,
expert_ids
,
num_tokens_past_padded
,
topk_weights
,
moe_block_size
,
top_k
,
mul_topk_weights
,
is_ep
,
b_q_type
.
id
,
size_m
,
size_n
,
size_k
,
is_k_full
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
)
input
,
output
,
b_qweight
,
b_scales
,
global_scale
,
b_qzeros
,
g_idx
,
perm
,
workspace
,
sorted_token_ids
,
expert_ids
,
num_tokens_past_padded
,
topk_weights
,
moe_block_size
,
top_k
,
mul_topk_weights
,
is_ep
,
b_q_type
.
id
,
size_m
,
size_n
,
size_k
,
is_k_full
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
)
if
supports_moe_ops
and
hasattr
(
torch
.
ops
.
_moe_C
,
"marlin_gemm_moe"
):
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
d74e5f37
...
...
@@ -25,6 +25,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
quant_type_id
:
int
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
global_scale1
:
Optional
[
torch
.
Tensor
]
=
None
,
global_scale2
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx1
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx2
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices1
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -64,11 +66,13 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
quant_type
=
ScalarType
.
from_id
(
quant_type_id
)
assert
quant_type
in
[
scalar_types
.
uint4
,
scalar_types
.
uint8b128
,
scalar_types
.
uint4b8
,
scalar_types
.
float8_e4m3fn
scalar_types
.
float8_e4m3fn
,
scalar_types
.
float4_e2m1f
]
int4_scalar_types
=
[
scalar_types
.
uint4
,
scalar_types
.
uint4b8
]
num_bits
=
4
if
quant_type
in
int4_scalar_types
else
8
bit4_scalar_types
=
[
scalar_types
.
uint4
,
scalar_types
.
uint4b8
,
scalar_types
.
float4_e2m1f
]
num_bits
=
4
if
quant_type
in
bit4_scalar_types
else
8
# Check constraints.
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
...
...
@@ -133,6 +137,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
intermediate_cache1
,
w1
,
w1_scale
,
global_scale1
,
w1_zeros
,
g_idx1
,
sort_indices1
,
...
...
@@ -165,6 +170,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
intermediate_cache3
,
w2
,
w2_scale
,
global_scale2
,
w2_zeros
,
g_idx2
,
sort_indices2
,
...
...
@@ -202,6 +208,8 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
topk_ids
:
torch
.
Tensor
,
quant_type_id
:
int
,
global_num_experts
:
int
=
-
1
,
global_scale1
:
Optional
[
torch
.
Tensor
]
=
None
,
global_scale2
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx1
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx2
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
vllm/model_executor/layers/quantization/hqq_marlin.py
View file @
d74e5f37
...
...
@@ -304,8 +304,10 @@ class HQQMarlinMethod(LinearMethodBase):
marlin_out
=
ops
.
gptq_marlin_gemm
(
x
,
None
,
layer
.
marlin_qweight
,
scales
,
None
,
zeros
,
layer
.
g_idx
,
layer
.
g_idx_sort_indices
,
...
...
@@ -315,7 +317,7 @@ class HQQMarlinMethod(LinearMethodBase):
self
.
output_size_per_partition
,
self
.
input_size_per_partition
,
True
,
# is_k_full
Tru
e
,
#
has_zp
Fals
e
,
#
use atomic add
True
,
# use 32-bit reduce
True
,
# use float zp
)
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
d74e5f37
...
...
@@ -17,6 +17,9 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
apply_fp4_marlin_linear
,
is_fp4_marlin_supported
,
prepare_fp4_layer_for_marlin
,
prepare_moe_fp4_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
...
...
@@ -24,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from
vllm.model_executor.parameter
import
(
ModelWeightParameter
,
PerTensorScaleParameter
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
...
...
@@ -196,7 +200,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
10
0
return
8
0
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
...
...
@@ -278,9 +282,15 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
ModelOptNvFp4Config
):
self
.
quant_config
=
quant_config
self
.
cutlass_nvfp4_supported
=
cutlass_fp4_supported
()
self
.
use_marlin
=
False
if
not
self
.
cutlass_nvfp4_supported
:
raise
ValueError
(
"Current platform does not support NVFP4"
" quantization. Please use Blackwell and above."
)
if
is_fp4_marlin_supported
():
self
.
use_marlin
=
True
else
:
raise
ValueError
(
"Current platform does not support NVFP4"
" quantization. Please use Blackwell and"
" above."
)
def
create_weights
(
self
,
...
...
@@ -392,12 +402,29 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer
.
weight_scale_swizzled
=
Parameter
(
swizzled_weight_scale
,
requires_grad
=
False
)
if
self
.
use_marlin
:
prepare_fp4_layer_for_marlin
(
layer
)
del
layer
.
alpha
del
layer
.
input_scale
del
layer
.
weight_scale_swizzled
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
self
.
use_marlin
:
return
apply_fp4_marlin_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
weight_scale_2
=
layer
.
weight_scale_2
,
workspace
=
layer
.
workspace
,
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
bias
=
bias
)
output_dtype
=
x
.
dtype
# for input only the contracting dimension has a constraint.
...
...
@@ -434,6 +461,16 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def
__init__
(
self
,
quant_config
:
ModelOptNvFp4Config
):
self
.
quant_config
=
quant_config
self
.
cutlass_nvfp4_supported
=
cutlass_fp4_supported
()
self
.
use_marlin
=
False
if
not
self
.
cutlass_nvfp4_supported
:
if
is_fp4_marlin_supported
():
self
.
use_marlin
=
True
else
:
raise
ValueError
(
"Current platform does not support NVFP4"
" quantization. Please use Blackwell and"
" above."
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
...
...
@@ -442,6 +479,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
raise
ValueError
(
"NVFP4 quantization was selected, "
" dynamic quantization is not supported."
)
layer
.
num_experts
=
num_experts
layer
.
params_dtype
=
params_dtype
layer
.
quant_config
=
self
.
quant_config
weight_dtype
=
torch
.
uint8
weight_scale_dtype
=
torch
.
float8_e4m3fn
...
...
@@ -594,7 +633,15 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer
.
w2_blockscale_swizzled
=
Parameter
(
w2_blockscale_swizzled
,
requires_grad
=
False
)
return
if
self
.
use_marlin
:
prepare_moe_fp4_layer_for_marlin
(
layer
)
del
layer
.
g1_alphas
del
layer
.
g2_alphas
del
layer
.
w13_input_scale_quant
del
layer
.
w2_input_scale_quant
del
layer
.
w13_blockscale_swizzled
del
layer
.
w2_blockscale_swizzled
def
apply
(
self
,
...
...
@@ -614,6 +661,35 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
):
if
self
.
use_marlin
:
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
)
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
layer
.
w13_weight_scale
,
layer
.
w2_weight_scale
,
router_logits
,
topk_weights
,
topk_ids
,
global_scale1
=
layer
.
w13_weight_scale_2
,
global_scale2
=
layer
.
w2_weight_scale_2
,
quant_type_id
=
scalar_types
.
float4_e2m1f
.
id
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
)
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
assert
not
apply_router_weight_on_input
,
(
"Router weight on input is not "
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
d74e5f37
...
...
@@ -33,7 +33,7 @@ USE_FP32_REDUCE_DEFAULT = True
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
def
query_marlin_supported_quant_types
(
has_zp
:
bool
,
has_zp
:
Optional
[
bool
]
=
None
,
include_fp_type
:
bool
=
True
,
device_capability
:
Optional
[
int
]
=
None
,
):
...
...
@@ -45,6 +45,16 @@ def query_marlin_supported_quant_types(
if
device_capability
<
80
:
return
[]
# - has_zp is True: return quant_types that has zero points
# - has_zp is False: return quant_types that has not zero points
# - has_zp is None: both
if
has_zp
is
None
:
types0
=
query_marlin_supported_quant_types
(
False
,
include_fp_type
,
device_capability
)
types1
=
query_marlin_supported_quant_types
(
True
,
include_fp_type
,
device_capability
)
return
types0
+
types1
if
has_zp
:
# AWQ style, unsigned + runtime zero-point
return
[
scalar_types
.
uint4
]
...
...
@@ -52,7 +62,7 @@ def query_marlin_supported_quant_types(
# GPTQ style, unsigned + symmetric bias
res
=
[
scalar_types
.
uint4b8
,
scalar_types
.
uint8b128
]
if
include_fp_type
:
res
+=
[
scalar_types
.
float8_e4m3fn
]
res
+=
[
scalar_types
.
float8_e4m3fn
,
scalar_types
.
float4_e2m1f
]
return
res
...
...
@@ -394,6 +404,7 @@ def apply_gptq_marlin_linear(
None
,
weight
,
weight_scale
,
None
,
weight_zp
,
g_idx
,
g_idx_sort_indices
,
...
...
@@ -439,6 +450,7 @@ def apply_awq_marlin_linear(
None
,
weight
,
weight_scale
,
None
,
weight_zp
,
g_idx
,
g_idx_sort_indices
,
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
0 → 100644
View file @
d74e5f37
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
import
torch
import
vllm._custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
USE_FP32_REDUCE_DEFAULT
,
marlin_make_workspace_new
,
marlin_permute_scales
,
should_use_atomic_add_reduce
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
FP4_MARLIN_SUPPORTED_GROUP_SIZES
=
[
16
]
logger
=
init_logger
(
__name__
)
def is_fp4_marlin_supported():
    """Report whether the current platform can use the FP4 Marlin path.

    The decision is delegated to
    ``current_platform.has_device_capability`` with a threshold of 80.
    """
    min_device_capability = 80
    return current_platform.has_device_capability(min_device_capability)
def
fp4_marlin_process_scales
(
marlin_scales
):
assert
(
marlin_scales
>=
0
).
all
()
# convert to half first, we would convert to fp8 later
marlin_scales
=
marlin_scales
.
to
(
torch
.
half
)
# 8 is the number of scale number using by one thread
marlin_scales
=
marlin_scales
.
view
(
marlin_scales
.
size
(
0
)
//
2
,
2
,
-
1
,
8
)
marlin_scales
=
marlin_scales
.
permute
(
0
,
2
,
1
,
3
).
reshape
(
marlin_scales
.
size
(
0
)
*
2
,
-
1
)
# fit the layout of fp8 dequantization
marlin_scales
=
marlin_scales
.
view
(
-
1
,
4
)[:,
[
0
,
2
,
1
,
3
]].
view
(
marlin_scales
.
size
(
0
),
-
1
)
# We assume that weight_scale (FP8-S1E4M3) is always greater
# than or equal to 0. So we can convert
# (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format.
# After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1
# when weight_scale > 0. This allows us to have an exponent bias
# closer to zero after dequantization.
marlin_scales
=
(
marlin_scales
*
(
2
**
7
)).
view
(
torch
.
int16
)
<<
1
marlin_scales
=
marlin_scales
.
view
(
torch
.
float8_e4m3fn
)
marlin_scales
=
marlin_scales
[:,
1
::
2
].
contiguous
()
return
marlin_scales
def fp4_marlin_process_global_scale(global_scale):
    """Rescale the global scale by a dtype-dependent power of two.

    The factor is ``2 ** (exponent_bias - 7)``, where ``exponent_bias`` is
    the difference between the exponent bias terms of the target dtype and
    of FP4 (2 exponent bits):

    * ``torch.half`` (5 exponent bits): ``2**4 - 2**1 = 14``
    * ``torch.bfloat16`` (8 exponent bits): ``2**7 - 2**1 = 126``

    Args:
        global_scale: tensor of dtype ``torch.half`` or ``torch.bfloat16``
            (asserted).

    Returns:
        ``global_scale`` multiplied by the factor described above.
    """
    scale_dtype = global_scale.dtype
    assert scale_dtype in [torch.half, torch.bfloat16]

    fp4_exponent_bits = 2
    if scale_dtype == torch.bfloat16:
        target_exponent_bits = 8
    elif scale_dtype == torch.half:
        target_exponent_bits = 5

    target_term = 2**(target_exponent_bits - 1)
    fp4_term = 2**(fp4_exponent_bits - 1)
    exponent_bias = target_term - fp4_term

    factor = 2.0**(exponent_bias - 7)
    return global_scale * factor
def apply_fp4_marlin_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        weight_scale_2: torch.Tensor,
        workspace: torch.Tensor,
        size_n: int,
        size_k: int,
        bias: Optional[torch.Tensor] = None,
        use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
    """Apply a weight-only FP4 linear layer via the Marlin GEMM kernel.

    For GPUs that lack FP4 hardware support, the Marlin kernel can be
    leveraged for fast weight-only FP4 quantization.

    Args:
        input: activations; all leading dimensions are flattened into one
            before the GEMM.
        weight: quantized weight, forwarded as ``b_q_weight``.
        weight_scale: forwarded as ``b_scales``.
        weight_scale_2: forwarded as ``global_scale``.
        workspace: Marlin workspace tensor.
        size_n: output feature count; becomes the last output dimension.
        size_k: forwarded as the GEMM ``size_k``.
        bias: optional tensor added in place to the GEMM result.
        use_fp32_reduce: forwarded to the kernel.

    Returns:
        The GEMM result reshaped to ``input.shape[:-1] + (size_n,)``.
    """
    x_2d = input.reshape(-1, input.shape[-1])
    result_shape = input.shape[:-1] + (size_n, )
    size_m = x_2d.size(0)

    atomic_add = should_use_atomic_add_reduce(m=size_m,
                                              n=size_n,
                                              k=size_k,
                                              device=input.device,
                                              dtype=input.dtype)

    # No zero points, group index or permutation are passed on the FP4 path.
    result = ops.gptq_marlin_gemm(a=x_2d,
                                  c=None,
                                  b_q_weight=weight,
                                  b_scales=weight_scale,
                                  global_scale=weight_scale_2,
                                  b_zeros=None,
                                  g_idx=None,
                                  perm=None,
                                  workspace=workspace,
                                  b_q_type=scalar_types.float4_e2m1f,
                                  size_m=size_m,
                                  size_n=size_n,
                                  size_k=size_k,
                                  use_atomic_add=atomic_add,
                                  use_fp32_reduce=use_fp32_reduce)

    if bias is not None:
        # In-place add avoids allocating a second output tensor.
        result.add_(bias)

    return result.reshape(result_shape)
def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
    """Convert an FP4 linear layer's parameters in place for the Marlin kernel.

    Reads ``output_size_per_partition``, ``input_size_per_partition``,
    ``params_dtype``, ``weight``, ``weight_scale`` and ``weight_scale_2``
    from ``layer``. It then sets ``layer.workspace`` and replaces
    ``layer.weight``, ``layer.weight_scale`` and ``layer.weight_scale_2``
    with non-trainable parameters holding the processed tensors.

    Args:
        layer: module whose ``weight`` has shape
            ``(output_size_per_partition, input_size_per_partition // 2)``
            (asserted).
    """
    logger.warning_once(
        "Your GPU does not have native support for FP4 computation but "
        "FP4 quantization is being used. Weight-only FP4 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    size_n = layer.output_size_per_partition
    size_k = layer.input_size_per_partition
    dtype = layer.params_dtype

    assert layer.weight.shape == (size_n, size_k // 2)

    dev = layer.weight.device

    # Workspace
    layer.workspace = marlin_make_workspace_new(dev)

    # Weight: repack to the Marlin format, passing an empty permutation.
    empty_perm = torch.empty(0, dtype=torch.int, device=dev)
    packed_int32 = layer.weight.view(torch.int32).T.contiguous()
    repacked = ops.gptq_marlin_repack(b_q_weight=packed_int32,
                                      perm=empty_perm,
                                      size_k=size_k,
                                      size_n=size_n,
                                      num_bits=4)
    layer.weight = torch.nn.Parameter(repacked, requires_grad=False)

    # Group scales: transpose, permute for Marlin, then apply the FP4
    # scale processing.
    scales = marlin_permute_scales(s=layer.weight_scale.T.to(dtype),
                                   size_k=size_k,
                                   size_n=size_n,
                                   group_size=16)
    scales = fp4_marlin_process_scales(scales)
    layer.weight_scale = torch.nn.Parameter(scales, requires_grad=False)

    # Global scale
    global_scale = fp4_marlin_process_global_scale(
        layer.weight_scale_2.to(dtype))
    layer.weight_scale_2 = torch.nn.Parameter(global_scale,
                                              requires_grad=False)

    return
def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
    """Convert an FP4-quantized MoE layer in place for the Marlin kernel.

    For both projections (``w13`` and ``w2``) the per-expert weights are
    repacked and the per-expert scales are permuted and processed; the
    results are stacked along a new leading expert dimension and written
    back onto ``layer`` as non-trainable parameters. The ``*_weight_scale_2``
    tensors are cast and processed as global scales. A new
    ``layer.workspace`` is attached as well.

    Args:
        layer: Module exposing ``num_experts``, ``hidden_size``,
            ``intermediate_size_per_partition``, ``params_dtype`` and the
            attributes ``w13_weight``, ``w2_weight``, ``w13_weight_scale``,
            ``w2_weight_scale``, ``w13_weight_scale_2`` and
            ``w2_weight_scale_2``.

    Raises:
        AssertionError: If a weight tensor is not shaped
            ``(num_experts, size_n, size_k // 2)``.
    """
    logger.warning_once(
        "Your GPU does not have native support for FP4 computation but "
        "FP4 quantization is being used. Weight-only FP4 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    e = layer.num_experts
    k = layer.hidden_size
    n = layer.intermediate_size_per_partition

    # WORKSPACE
    device = layer.w13_weight.device
    param_dtype = layer.params_dtype
    layer.workspace = marlin_make_workspace_new(device, 4)
    empty_perm = torch.empty(0, dtype=torch.int, device=device)

    # (size_n, size_k) of each projection's GEMM: w13 has 2 * n output
    # columns over k inputs, while w2 maps n inputs back to k outputs.
    gemm_dims = {"w13": (n * 2, k), "w2": (k, n)}

    # WEIGHT
    # Repack weights to marlin format, one expert at a time.
    for prefix, (size_n, size_k) in gemm_dims.items():
        weight_name = prefix + "_weight"
        packed = getattr(layer, weight_name)

        assert packed.shape == (e, size_n, size_k // 2)

        repacked_experts = [
            ops.gptq_marlin_repack(
                b_q_weight=packed[expert].view(torch.int32).T.contiguous(),
                perm=empty_perm,
                size_k=size_k,
                size_n=size_n,
                num_bits=4,
            ) for expert in range(e)
        ]

        stacked_weight = torch.stack(repacked_experts, 0)
        setattr(layer, weight_name,
                torch.nn.Parameter(stacked_weight, requires_grad=False))

    # WEIGHT SCALES
    # Permute and process scales per expert; process the global scale once
    # per projection.
    for prefix, (size_n, size_k) in gemm_dims.items():
        scale_name = prefix + "_weight_scale"
        global_name = prefix + "_weight_scale_2"

        raw_scales = getattr(layer, scale_name).to(param_dtype)
        raw_global = getattr(layer, global_name).to(param_dtype)

        processed_experts = [
            fp4_marlin_process_scales(
                marlin_permute_scales(
                    s=raw_scales[expert].T,
                    size_k=size_k,
                    size_n=size_n,
                    group_size=16,
                )) for expert in range(e)
        ]

        stacked_scales = torch.stack(processed_experts, 0)
        setattr(layer, scale_name,
                torch.nn.Parameter(stacked_scales, requires_grad=False))

        processed_global = fp4_marlin_process_global_scale(raw_global)
        setattr(layer, global_name,
                torch.nn.Parameter(processed_global, requires_grad=False))
def rand_marlin_weight_fp4_like(weight, group_size):
    """Build a random FP4 weight plus Marlin-format tensors shaped like ``weight``.

    Only the per-group magnitudes of ``weight`` are used (to derive the
    scales); the quantized payload itself is drawn uniformly at random
    with ``torch.randint``.

    Args:
        weight: 2-D tensor of shape ``(size_n, size_k)``. ``size_k`` must be
            viewable as groups of ``group_size``.
        group_size: Number of consecutive elements along the last dimension
            that share one scale. Must be positive.

    Returns:
        A 4-tuple ``(weight_ref.T, marlin_qweight, marlin_scales,
        global_scale)``: the transposed dequantized reference weight, the
        Marlin-repacked quantized weight, the permuted and processed
        per-group scales, and the processed global scale.

    Raises:
        AssertionError: If ``group_size`` is not positive.
    """
    assert group_size > 0
    size_n, size_k = weight.shape
    device = weight.device

    def _decode_upper_nibble(packed):
        # Keep the top bit in place and shift the next three bits down by
        # two so the byte can be reinterpreted as float8_e4m3fn. The 2**6
        # factor presumably compensates for the exponent-bias difference
        # between the 4-bit and 8-bit formats -- confirm against the kernel.
        fp8_bits = (packed & 0b10000000) | ((packed & 0b01110000) >> 2)
        return fp8_bits.view(torch.float8_e4m3fn).to(weight.dtype) * (2**6)

    # Per-group scale from the group's max magnitude, then a single global
    # scale so the per-group scales can be stored as float8_e4m3fn.
    # NOTE(review): 6 and 448 look like the largest finite values of the
    # FP4 (e2m1) and FP8 (e4m3fn) formats -- confirm.
    group_absmax = weight.view(size_n, -1, group_size).abs().max(-1).values
    scales = group_absmax / 6
    global_scale = scales.max() / 448
    scales = (scales / global_scale).to(torch.float8_e4m3fn)

    # Each random byte carries two 4-bit values.
    fp4_weight = torch.randint(0,
                               256, (size_n, size_k // 2),
                               dtype=torch.uint8,
                               device=device)
    upper_vals = _decode_upper_nibble(fp4_weight)
    # Shifting left by 4 moves the lower nibble into the upper position.
    lower_vals = _decode_upper_nibble(fp4_weight << 4)

    # Interleave so that, for every byte, the lower-nibble value comes
    # first and the upper-nibble value second.
    weight_ref = torch.stack([lower_vals, upper_vals], 2).view(size_n, size_k)
    expanded_scales = scales.repeat_interleave(group_size, 1).to(weight.dtype)
    weight_ref = weight_ref * global_scale.to(weight.dtype) * expanded_scales

    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
        perm=torch.empty(0, dtype=torch.int, device=device),
        size_k=size_k,
        size_n=size_n,
        num_bits=4,
    )

    marlin_scales = fp4_marlin_process_scales(
        marlin_permute_scales(
            s=scales.T.to(weight.dtype),
            size_k=size_k,
            size_n=size_n,
            group_size=group_size,
        ))

    global_scale = fp4_marlin_process_global_scale(global_scale)

    return weight_ref.T, marlin_qweight, marlin_scales, global_scale
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment