Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7a985548
Commit
7a985548
authored
May 22, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.9.0' into v0.9.0-ori
parents
45d3785c
dc1440cf
Changes
486
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1081 additions
and
1255 deletions
+1081
-1255
csrc/cpu/quant.cpp
csrc/cpu/quant.cpp
+337
-4
csrc/cpu/torch_bindings.cpp
csrc/cpu/torch_bindings.cpp
+36
-1
csrc/cutlass_extensions/common.hpp
csrc/cutlass_extensions/common.hpp
+10
-0
csrc/dispatch_utils.h
csrc/dispatch_utils.h
+14
-0
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+4
-0
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu
+0
-31
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h
+0
-20
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu
+0
-31
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h
+0
-20
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
+0
-31
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
+0
-18
csrc/moe/marlin_moe_ops.cu
csrc/moe/marlin_moe_ops.cu
+0
-588
csrc/moe/marlin_moe_wna16/.gitignore
csrc/moe/marlin_moe_wna16/.gitignore
+1
-0
csrc/moe/marlin_moe_wna16/generate_kernels.py
csrc/moe/marlin_moe_wna16/generate_kernels.py
+22
-9
csrc/moe/marlin_moe_wna16/kernel.h
csrc/moe/marlin_moe_wna16/kernel.h
+16
-17
csrc/moe/marlin_moe_wna16/marlin_template.h
csrc/moe/marlin_moe_wna16/marlin_template.h
+291
-281
csrc/moe/marlin_moe_wna16/ops.cu
csrc/moe/marlin_moe_wna16/ops.cu
+205
-192
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+4
-4
csrc/moe/moe_permute_unpermute_op.cu
csrc/moe/moe_permute_unpermute_op.cu
+133
-0
csrc/moe/moe_wna16_utils.h
csrc/moe/moe_wna16_utils.h
+8
-8
No files found.
Too many changes to show.
To preserve performance only
486 of 486+
files are displayed.
Plain diff
Email patch
csrc/cpu/quant.cpp
View file @
7a985548
...
...
@@ -239,6 +239,280 @@ void static_quant_epilogue(const float* input, scalar_t* output,
}
}
template
<
bool
AZP
,
bool
PerChannel
,
bool
Bias
,
typename
scalar_t
>
void
dynamic_quant_epilogue
(
const
float
*
input
,
scalar_t
*
output
,
const
float
*
a_scale
,
const
float
*
b_scale
,
const
int32_t
*
azp
,
const
int32_t
*
azp_adj
,
const
scalar_t
*
bias
,
const
int
num_tokens
,
const
int
hidden_size
)
{
CPU_KERNEL_GUARD_IN
(
dynamic_quant_epilogue
)
using
load_vec_t
=
typename
KernelVecType
<
scalar_t
>::
load_vec_type
;
using
azp_adj_load_vec_t
=
typename
KernelVecType
<
scalar_t
>::
azp_adj_load_vec_type
;
using
cvt_vec_t
=
typename
KernelVecType
<
scalar_t
>::
cvt_vec_type
;
constexpr
int
vec_elem_num
=
load_vec_t
::
VEC_ELEM_NUM
;
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
num_tokens
;
++
i
)
{
int
j
=
0
;
cvt_vec_t
token_scale_vec
(
a_scale
[
i
]);
cvt_vec_t
token_zp_scale_vec
;
if
constexpr
(
AZP
)
{
float
zp_scale_val
=
a_scale
[
i
]
*
static_cast
<
float
>
(
azp
[
i
]);
if
constexpr
(
!
PerChannel
)
{
zp_scale_val
*=
*
b_scale
;
}
token_zp_scale_vec
=
cvt_vec_t
(
zp_scale_val
);
}
for
(;
j
<
hidden_size
-
vec_elem_num
;
j
+=
vec_elem_num
)
{
cvt_vec_t
elems_fp32
(
input
+
i
*
hidden_size
+
j
);
elems_fp32
=
elems_fp32
*
token_scale_vec
;
if
constexpr
(
AZP
)
{
azp_adj_load_vec_t
azp_adj_vec
(
azp_adj
+
j
);
cvt_vec_t
azp_adj_fp32
(
azp_adj_vec
);
azp_adj_fp32
=
azp_adj_fp32
*
token_zp_scale_vec
;
if
constexpr
(
PerChannel
)
{
cvt_vec_t
b_scale_vec
(
b_scale
+
j
);
azp_adj_fp32
=
azp_adj_fp32
*
b_scale_vec
;
}
elems_fp32
=
elems_fp32
-
azp_adj_fp32
;
}
if
constexpr
(
Bias
)
{
load_vec_t
bias_vec
(
bias
+
j
);
cvt_vec_t
bias_vec_fp32
(
bias_vec
);
elems_fp32
=
elems_fp32
+
bias_vec_fp32
;
}
load_vec_t
elems_out
(
elems_fp32
);
elems_out
.
save
(
output
+
i
*
hidden_size
+
j
);
}
cvt_vec_t
elems_fp32
(
input
+
i
*
hidden_size
+
j
);
elems_fp32
=
elems_fp32
*
token_scale_vec
;
if
constexpr
(
AZP
)
{
azp_adj_load_vec_t
azp_adj_vec
(
azp_adj
+
j
);
cvt_vec_t
azp_adj_fp32
(
azp_adj_vec
);
azp_adj_fp32
=
azp_adj_fp32
*
token_zp_scale_vec
;
if
constexpr
(
PerChannel
)
{
cvt_vec_t
b_scale_vec
(
b_scale
+
j
);
azp_adj_fp32
=
azp_adj_fp32
*
b_scale_vec
;
}
elems_fp32
=
elems_fp32
-
azp_adj_fp32
;
}
if
constexpr
(
Bias
)
{
load_vec_t
bias_vec
(
bias
+
j
);
cvt_vec_t
bias_vec_fp32
(
bias_vec
);
elems_fp32
=
elems_fp32
+
bias_vec_fp32
;
}
load_vec_t
elems_out
(
elems_fp32
);
elems_out
.
save
(
output
+
i
*
hidden_size
+
j
,
hidden_size
-
j
);
}
}
#elif defined(__powerpc64__)
template
<
bool
AZP
,
typename
scalar_t
>
void
static_scaled_int8_quant_impl
(
const
scalar_t
*
input
,
int8_t
*
output
,
const
float
*
scale
,
const
int32_t
*
azp
,
const
int
num_tokens
,
const
int
hidden_size
)
{
using
load_vec_t
=
typename
KernelVecType
<
scalar_t
>::
load_vec_type
;
using
cvt_vec_t
=
typename
KernelVecType
<
scalar_t
>::
cvt_vec_type
;
constexpr
int
vec_elem_num
=
load_vec_t
::
VEC_ELEM_NUM
;
constexpr
float
i8_min
=
static_cast
<
float
>
(
std
::
numeric_limits
<
int8_t
>::
min
());
constexpr
float
i8_max
=
static_cast
<
float
>
(
std
::
numeric_limits
<
int8_t
>::
max
());
const
cvt_vec_t
inv_scale
(
1.0
/
*
scale
);
const
cvt_vec_t
i8_min_vec
(
i8_min
);
const
cvt_vec_t
i8_max_vec
(
i8_max
);
cvt_vec_t
zp_vec
;
if
constexpr
(
AZP
)
{
zp_vec
=
cvt_vec_t
(
static_cast
<
float
>
(
*
azp
));
}
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
num_tokens
;
++
i
)
{
int
j
=
0
;
for
(;
j
<
hidden_size
-
vec_elem_num
;
j
+=
vec_elem_num
)
{
load_vec_t
elems
(
input
+
i
*
hidden_size
+
j
);
cvt_vec_t
elems_fp32
(
elems
);
elems_fp32
=
elems_fp32
*
inv_scale
;
if
constexpr
(
AZP
)
{
elems_fp32
=
elems_fp32
+
zp_vec
;
}
elems_fp32
=
elems_fp32
.
clamp
(
i8_min_vec
,
i8_max_vec
);
vec_op
::
INT8Vec16
elems_int8
(
elems_fp32
);
elems_int8
.
save
(
output
+
i
*
hidden_size
+
j
);
}
load_vec_t
elems
(
input
+
i
*
hidden_size
+
j
);
cvt_vec_t
elems_fp32
(
elems
);
elems_fp32
=
elems_fp32
*
inv_scale
;
if
constexpr
(
AZP
)
{
elems_fp32
=
elems_fp32
+
zp_vec
;
}
elems_fp32
=
elems_fp32
.
clamp
(
i8_min_vec
,
i8_max_vec
);
vec_op
::
INT8Vec16
elems_int8
(
elems_fp32
);
elems_int8
.
save
(
output
+
i
*
hidden_size
+
j
,
hidden_size
-
j
);
}
}
template
<
bool
AZP
,
typename
scalar_t
>
void
dynamic_scaled_int8_quant_impl
(
const
scalar_t
*
input
,
int8_t
*
output
,
float
*
scale
,
int32_t
*
azp
,
const
int
num_tokens
,
const
int
hidden_size
)
{
using
load_vec_t
=
typename
KernelVecType
<
scalar_t
>::
load_vec_type
;
using
cvt_vec_t
=
typename
KernelVecType
<
scalar_t
>::
cvt_vec_type
;
constexpr
int
vec_elem_num
=
load_vec_t
::
VEC_ELEM_NUM
;
constexpr
float
i8_min
=
static_cast
<
float
>
(
std
::
numeric_limits
<
int8_t
>::
min
());
constexpr
float
i8_max
=
static_cast
<
float
>
(
std
::
numeric_limits
<
int8_t
>::
max
());
const
cvt_vec_t
i8_min_vec
(
i8_min
);
const
cvt_vec_t
i8_max_vec
(
i8_max
);
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
num_tokens
;
++
i
)
{
cvt_vec_t
max_value
(
std
::
numeric_limits
<
float
>::
lowest
());
cvt_vec_t
min_value
(
std
::
numeric_limits
<
float
>::
max
());
{
int
j
=
0
;
for
(;
j
<
hidden_size
-
vec_elem_num
;
j
+=
vec_elem_num
)
{
load_vec_t
elems
(
input
+
i
*
hidden_size
+
j
);
cvt_vec_t
elems_fp32
(
elems
);
if
constexpr
(
AZP
)
{
max_value
=
max_value
.
max
(
elems_fp32
);
min_value
=
min_value
.
min
(
elems_fp32
);
}
else
{
max_value
=
max_value
.
max
(
elems_fp32
.
abs
());
}
}
load_vec_t
elems
(
input
+
i
*
hidden_size
+
j
);
cvt_vec_t
elems_fp32
(
elems
);
if
(
j
+
vec_elem_num
==
hidden_size
)
{
if
constexpr
(
AZP
)
{
max_value
=
max_value
.
max
(
elems_fp32
);
min_value
=
min_value
.
min
(
elems_fp32
);
}
else
{
max_value
=
max_value
.
max
(
elems_fp32
.
abs
());
}
}
else
{
if
constexpr
(
AZP
)
{
max_value
=
max_value
.
max
(
elems_fp32
,
hidden_size
-
j
);
min_value
=
min_value
.
min
(
elems_fp32
,
hidden_size
-
j
);
}
else
{
max_value
=
max_value
.
max
(
elems_fp32
.
abs
(),
hidden_size
-
j
);
}
}
}
float
scale_val
,
azp_val
;
if
constexpr
(
AZP
)
{
float
max_scalar
=
max_value
.
reduce_max
();
float
min_scalar
=
min_value
.
reduce_min
();
scale_val
=
(
max_scalar
-
min_scalar
)
/
255.0
f
;
azp_val
=
std
::
nearbyint
(
-
128.0
f
-
min_scalar
/
scale_val
);
azp
[
i
]
=
static_cast
<
int32_t
>
(
azp_val
);
scale
[
i
]
=
scale_val
;
}
else
{
scale_val
=
max_value
.
reduce_max
()
/
127.0
f
;
scale
[
i
]
=
scale_val
;
}
const
cvt_vec_t
inv_scale
(
1.0
/
scale_val
);
const
cvt_vec_t
azp_vec
(
azp_val
);
{
int
j
=
0
;
for
(;
j
<
hidden_size
-
vec_elem_num
;
j
+=
vec_elem_num
)
{
load_vec_t
elems
(
input
+
i
*
hidden_size
+
j
);
cvt_vec_t
elems_fp32
(
elems
);
elems_fp32
=
(
elems_fp32
*
inv_scale
);
if
constexpr
(
AZP
)
{
elems_fp32
=
elems_fp32
+
azp_vec
;
}
elems_fp32
=
elems_fp32
.
clamp
(
i8_min_vec
,
i8_max_vec
);
vec_op
::
INT8Vec16
elems_int8
(
elems_fp32
);
elems_int8
.
save
(
output
+
i
*
hidden_size
+
j
);
}
load_vec_t
elems
(
input
+
i
*
hidden_size
+
j
);
cvt_vec_t
elems_fp32
(
elems
);
elems_fp32
=
(
elems_fp32
*
inv_scale
);
if
constexpr
(
AZP
)
{
elems_fp32
=
elems_fp32
+
azp_vec
;
}
elems_fp32
=
elems_fp32
.
clamp
(
i8_min_vec
,
i8_max_vec
);
vec_op
::
INT8Vec16
elems_int8
(
elems_fp32
);
elems_int8
.
save
(
output
+
i
*
hidden_size
+
j
,
hidden_size
-
j
);
}
}
}
template
<
bool
PerChannel
,
typename
scalar_t
>
void
static_quant_epilogue
(
const
float
*
input
,
scalar_t
*
output
,
const
float
a_scale
,
const
float
*
b_scale
,
const
int32_t
*
azp_with_adj
,
const
int
num_tokens
,
const
int
hidden_size
)
{
CPU_KERNEL_GUARD_IN
(
dynamic_output_scale_impl
)
using
load_vec_t
=
typename
KernelVecType
<
scalar_t
>::
load_vec_type
;
using
azp_adj_load_vec_t
=
typename
KernelVecType
<
scalar_t
>::
azp_adj_load_vec_type
;
using
cvt_vec_t
=
typename
KernelVecType
<
scalar_t
>::
cvt_vec_type
;
constexpr
int
vec_elem_num
=
load_vec_t
::
VEC_ELEM_NUM
;
#pragma omp parallel for
for
(
int
i
=
0
;
i
<
num_tokens
;
++
i
)
{
cvt_vec_t
a_scale_vec
(
a_scale
);
cvt_vec_t
b_scale_vec
(
*
b_scale
);
cvt_vec_t
scale_vec
=
a_scale_vec
*
b_scale_vec
;
int
j
=
0
;
for
(;
j
<
hidden_size
-
vec_elem_num
;
j
+=
vec_elem_num
)
{
cvt_vec_t
elems_fp32
(
input
+
i
*
hidden_size
+
j
);
azp_adj_load_vec_t
azp_adj_vec
(
azp_with_adj
+
j
);
cvt_vec_t
azp_adj_fp32
(
azp_adj_vec
);
if
constexpr
(
PerChannel
)
{
b_scale_vec
=
cvt_vec_t
(
b_scale
+
j
);
scale_vec
=
b_scale_vec
*
a_scale_vec
;
}
elems_fp32
=
elems_fp32
-
scale_vec
*
azp_adj_fp32
;
load_vec_t
elems_out
(
elems_fp32
);
elems_out
.
save
(
output
+
i
*
hidden_size
+
j
);
}
cvt_vec_t
elems_fp32
(
input
+
i
*
hidden_size
+
j
);
azp_adj_load_vec_t
azp_adj_vec
(
azp_with_adj
+
j
);
cvt_vec_t
azp_adj_fp32
(
azp_adj_vec
);
if
constexpr
(
PerChannel
)
{
b_scale_vec
=
cvt_vec_t
(
b_scale
+
j
);
scale_vec
=
b_scale_vec
*
a_scale_vec
;
}
elems_fp32
=
elems_fp32
-
scale_vec
*
azp_adj_fp32
;
load_vec_t
elems_out
(
elems_fp32
);
elems_out
.
save
(
output
+
i
*
hidden_size
+
j
,
hidden_size
-
j
);
}
}
template
<
bool
AZP
,
bool
PerChannel
,
bool
Bias
,
typename
scalar_t
>
void
dynamic_quant_epilogue
(
const
float
*
input
,
scalar_t
*
output
,
const
float
*
a_scale
,
const
float
*
b_scale
,
...
...
@@ -324,7 +598,8 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
const
float
*
scale
,
const
int32_t
*
azp
,
const
int
num_tokens
,
const
int
hidden_size
)
{
TORCH_CHECK
(
false
,
"static_scaled_int8_quant_impl requires AVX512 support."
)
TORCH_CHECK
(
false
,
"static_scaled_int8_quant_impl requires AVX512/powerpc64 support."
)
}
template
<
typename
scalar_t
>
...
...
@@ -332,7 +607,9 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
float
*
scale
,
int32_t
*
azp
,
const
int
num_tokens
,
const
int
hidden_size
)
{
TORCH_CHECK
(
false
,
"dynamic_scaled_int8_quant_impl requires AVX512 support."
)
TORCH_CHECK
(
false
,
"dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support."
)
}
template
<
bool
PerChannel
,
typename
scalar_t
>
...
...
@@ -340,7 +617,7 @@ void static_quant_epilogue(const float* input, scalar_t* output,
const
float
a_scale
,
const
float
*
b_scale
,
const
int32_t
*
azp_with_adj
,
const
int
num_tokens
,
const
int
hidden_size
)
{
TORCH_CHECK
(
false
,
"static_quant_epilogue requires AVX512 support."
)
TORCH_CHECK
(
false
,
"static_quant_epilogue requires AVX512
/powerpc64
support."
)
}
template
<
typename
scalar_t
>
...
...
@@ -349,7 +626,8 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output,
const
int32_t
*
azp
,
const
int32_t
*
azp_with_adj
,
const
scalar_t
*
bias
,
const
int
num_tokens
,
const
int
hidden_size
)
{
TORCH_CHECK
(
false
,
"dynamic_quant_epilogue requires AVX512 support."
)
TORCH_CHECK
(
false
,
"dynamic_quant_epilogue requires AVX512/powerpc64 support."
)
}
#endif
}
// namespace
...
...
@@ -611,3 +889,58 @@ void dynamic_scaled_int8_quant(
}
});
}
#if defined(__powerpc64__)
void
int8_scaled_mm_ppc64le
(
torch
::
Tensor
&
c
,
// [M, OC], row-major
const
torch
::
Tensor
&
a
,
// [M, IC], row-major
const
torch
::
Tensor
&
b
,
// [IC, OC], column-major
const
torch
::
Tensor
&
a_scales
,
const
torch
::
Tensor
&
b_scales
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
// [OC]
)
{
CPU_KERNEL_GUARD_IN
(
cutlass_scaled_mm
)
// Checks for conformality
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
&&
b
.
dtype
()
==
torch
::
kInt8
,
"int8_scaled_mm_ppc64le only supports INT8 inputs."
);
TORCH_CHECK
(
a
.
dim
()
==
2
&&
b
.
dim
()
==
2
&&
c
.
dim
()
==
2
);
TORCH_CHECK
(
c
.
size
(
0
)
==
a
.
size
(
0
)
&&
a
.
size
(
1
)
==
b
.
size
(
0
)
&&
b
.
size
(
1
)
==
c
.
size
(
1
));
// We dont need this
TORCH_CHECK
(
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
));
TORCH_CHECK
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
));
// Check for strides and alignment
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
&&
c
.
stride
(
1
)
==
1
);
// Row-major
TORCH_CHECK
(
b
.
stride
(
0
)
==
1
);
// Column-major
TORCH_CHECK
(
c
.
stride
(
0
)
%
16
==
0
&&
b
.
stride
(
1
)
%
16
==
0
);
// 16 Byte Alignment
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
bias
)
{
TORCH_CHECK
(
bias
->
numel
()
==
b
.
size
(
1
)
&&
bias
->
is_contiguous
()
&&
bias
->
dim
()
==
1
);
}
VLLM_DISPATCH_FLOATING_TYPES
(
c
.
scalar_type
(),
"int8_scaled_mm_ppc64le"
,
[
&
]
{
torch
::
Tensor
tmp_fp32_out
=
torch
::
empty_like
(
c
,
::
at
::
ScalarType
::
Float
);
// Compute C_inter=s_b * (A@B)
DNNLPrimitiveHelper
<
true
>::
gemm_s8s8_jit
<
float
,
void
>
(
a
.
data_ptr
<
int8_t
>
(),
b
.
data_ptr
<
int8_t
>
(),
tmp_fp32_out
.
data_ptr
<
float
>
(),
nullptr
,
a
.
size
(
0
),
b
.
size
(
1
),
a
.
size
(
1
),
nullptr
,
b_scales
.
data_ptr
<
float
>
(),
0
,
b_scales
.
numel
());
if
(
bias
.
has_value
())
{
// Compute C=s_a * C_inter + bias
dynamic_quant_epilogue
<
false
,
true
,
true
>
(
tmp_fp32_out
.
data_ptr
<
float
>
(),
c
.
data_ptr
<
scalar_t
>
(),
a_scales
.
data_ptr
<
float
>
(),
nullptr
,
nullptr
,
nullptr
,
bias
->
data_ptr
<
scalar_t
>
(),
c
.
size
(
0
),
c
.
size
(
1
));
}
else
{
// Compute C=s_a * C_inter
dynamic_quant_epilogue
<
false
,
true
,
false
,
scalar_t
>
(
tmp_fp32_out
.
data_ptr
<
float
>
(),
c
.
data_ptr
<
scalar_t
>
(),
a_scales
.
data_ptr
<
float
>
(),
nullptr
,
nullptr
,
nullptr
,
nullptr
,
c
.
size
(
0
),
c
.
size
(
1
));
}
});
}
#endif
csrc/cpu/torch_bindings.cpp
View file @
7a985548
...
...
@@ -18,6 +18,14 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
const
std
::
optional
<
torch
::
Tensor
>&
azp
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
);
#if defined(__powerpc64__)
void
int8_scaled_mm_ppc64le
(
torch
::
Tensor
&
c
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
a_scales
,
const
torch
::
Tensor
&
b_scales
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
);
#endif
void
mla_decode_kvcache
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
kv_cache
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
);
...
...
@@ -117,7 +125,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops
.
def
(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor!
?
key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()"
);
ops
.
impl
(
"rotary_embedding"
,
torch
::
kCPU
,
&
rotary_embedding
);
...
...
@@ -150,6 +158,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()"
);
ops
.
impl
(
"cutlass_scaled_mm_azp"
,
torch
::
kCPU
,
&
int8_scaled_mm_azp
);
#elif defined(__powerpc64__)
// Compute int8 quantized tensor for given scaling factor.
ops
.
def
(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
"Tensor? azp) -> ()"
);
ops
.
impl
(
"static_scaled_int8_quant"
,
torch
::
kCPU
,
&
static_scaled_int8_quant
);
// Compute int8 quantized tensor and scaling factor
ops
.
def
(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
"Tensor!? azp) -> ()"
);
ops
.
impl
(
"dynamic_scaled_int8_quant"
,
torch
::
kCPU
,
&
dynamic_scaled_int8_quant
);
// W8A8 GEMM, supporting symmetric quantization.
ops
.
def
(
"cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()"
);
ops
.
impl
(
"cutlass_scaled_mm"
,
torch
::
kCPU
,
&
int8_scaled_mm_ppc64le
);
// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops
.
def
(
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()"
);
ops
.
impl
(
"cutlass_scaled_mm_azp"
,
torch
::
kCPU
,
&
int8_scaled_mm_azp
);
#endif
// SHM CCL
...
...
csrc/cutlass_extensions/common.hpp
View file @
7a985548
...
...
@@ -59,3 +59,13 @@ struct enable_sm90_only : Kernel {
#endif
}
};
template
<
typename
Kernel
>
struct
enable_sm100_only
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
void
operator
()(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000
Kernel
::
operator
()(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
csrc/dispatch_utils.h
View file @
7a985548
...
...
@@ -65,5 +65,19 @@
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
csrc/layernorm_kernels.cu
View file @
7a985548
...
...
@@ -140,6 +140,10 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
double
epsilon
)
{
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
weight
.
is_contiguous
());
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
...
...
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu
deleted
100644 → 0
View file @
45d3785c
#include "marlin_moe_kernel_ku4.h"
namespace
marlin_moe
{
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool
call_marlin_moe_kernel_ku4
(
vllm
::
ScalarType
const
&
q_type
,
int
thread_n_blocks
,
int
thread_k_blocks
,
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int4
*
zp_ptr
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
)
{
bool
has_zp
=
true
;
if
(
false
)
{
}
AWQ_CALL_IF_MOE
(
vllm
::
kU4
,
16
,
4
,
256
)
AWQ_CALL_IF_MOE
(
vllm
::
kU4
,
8
,
8
,
256
)
AWQ_CALL_IF_MOE
(
vllm
::
kU4
,
8
,
4
,
128
)
AWQ_CALL_IF_MOE
(
vllm
::
kU4
,
4
,
8
,
128
)
else
{
return
false
;
}
return
true
;
}
}
// namespace marlin_moe
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h
deleted
100644 → 0
View file @
45d3785c
#pragma once
#include "marlin_moe_kernel.h"
namespace
marlin_moe
{
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool
call_marlin_moe_kernel_ku4
(
vllm
::
ScalarType
const
&
q_type
,
int
thread_n_blocks
,
int
thread_k_blocks
,
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int4
*
zp_ptr
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
);
}
// namespace marlin_moe
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu
deleted
100644 → 0
View file @
45d3785c
#include "marlin_moe_kernel_ku4b8.h"
namespace
marlin_moe
{
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool
call_marlin_moe_kernel_ku4b8
(
vllm
::
ScalarType
const
&
q_type
,
int
thread_n_blocks
,
int
thread_k_blocks
,
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int4
*
zp_ptr
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
)
{
bool
has_zp
=
false
;
if
(
false
)
{
}
GPTQ_CALL_IF_MOE
(
vllm
::
kU4B8
,
16
,
4
,
256
)
GPTQ_CALL_IF_MOE
(
vllm
::
kU4B8
,
8
,
8
,
256
)
GPTQ_CALL_IF_MOE
(
vllm
::
kU4B8
,
8
,
4
,
128
)
GPTQ_CALL_IF_MOE
(
vllm
::
kU4B8
,
4
,
8
,
128
)
else
{
return
false
;
}
return
true
;
}
}
// namespace marlin_moe
csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h
deleted
100644 → 0
View file @
45d3785c
#pragma once
#include "marlin_moe_kernel.h"
namespace
marlin_moe
{
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool
call_marlin_moe_kernel_ku4b8
(
vllm
::
ScalarType
const
&
q_type
,
int
thread_n_blocks
,
int
thread_k_blocks
,
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int4
*
zp_ptr
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
);
}
// namespace marlin_moe
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
deleted
100644 → 0
View file @
45d3785c
#include "marlin_moe_kernel_ku8b128.h"
namespace
marlin_moe
{
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool
call_marlin_moe_kernel_ku8b128
(
vllm
::
ScalarType
const
&
q_type
,
int
thread_n_blocks
,
int
thread_k_blocks
,
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int4
*
zp_ptr
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
)
{
bool
has_zp
=
false
;
if
(
false
)
{
}
GPTQ_CALL_IF_MOE
(
vllm
::
kU8B128
,
16
,
4
,
256
)
GPTQ_CALL_IF_MOE
(
vllm
::
kU8B128
,
8
,
8
,
256
)
GPTQ_CALL_IF_MOE
(
vllm
::
kU8B128
,
8
,
4
,
128
)
GPTQ_CALL_IF_MOE
(
vllm
::
kU8B128
,
4
,
8
,
128
)
else
{
return
false
;
}
return
true
;
}
}
// namespace marlin_moe
csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
deleted
100644 → 0
View file @
45d3785c
#pragma once
#include "marlin_moe_kernel.h"
namespace
marlin_moe
{
bool
call_marlin_moe_kernel_ku8b128
(
vllm
::
ScalarType
const
&
q_type
,
int
thread_n_blocks
,
int
thread_k_blocks
,
bool
has_act_order
,
int
group_blocks
,
int
num_threads
,
int
blocks
,
int
max_shared_mem
,
cudaStream_t
stream
,
const
int4
*
A_ptr
,
const
int4
*
B_ptr
,
int4
*
C_ptr
,
const
int
*
sorted_ids_ptr
,
const
float
*
topk_weights_ptr
,
const
int4
*
s_ptr
,
const
int4
*
zp_ptr
,
const
int
*
g_idx_ptr
,
int
*
expert_offsets_ptr
,
int
num_groups
,
int
expert_idx
,
int
num_experts
,
int
topk
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
tot_m
,
int
*
locks
,
bool
replicate_input
,
bool
apply_weights
,
int
m_block
,
int
max_par
,
int
cfg_max_m_blocks
);
}
csrc/moe/marlin_moe_ops.cu
deleted
100644 → 0
View file @
45d3785c
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include "core/exception.hpp"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
#include "marlin_kernels/marlin_moe_kernel_ku4.h"
template
<
typename
T
>
inline
std
::
string
str
(
T
x
)
{
return
std
::
to_string
(
x
);
}
namespace
marlin_moe
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
__global__
void
permute_cols_kernel
(
int4
const
*
__restrict__
a_int4_ptr
,
int
const
*
__restrict__
perm_int_ptr
,
int4
*
__restrict__
out_int4_ptr
,
int
size_m
,
int
size_k
,
int
block_rows
)
{
int
start_row
=
block_rows
*
blockIdx
.
x
;
int
finish_row
=
start_row
+
block_rows
;
if
(
finish_row
>
size_m
)
{
finish_row
=
size_m
;
}
int
cur_block_rows
=
finish_row
-
start_row
;
int
row_stride
=
size_k
*
sizeof
(
half
)
/
16
;
auto
permute_row
=
[
&
](
int
row
)
{
int
iters
=
size_k
/
blockDim
.
x
;
int
rest
=
size_k
%
blockDim
.
x
;
int
offset
=
row
*
row_stride
;
half
const
*
a_row_half
=
reinterpret_cast
<
half
const
*>
(
a_int4_ptr
+
offset
);
half
*
out_half
=
reinterpret_cast
<
half
*>
(
out_int4_ptr
+
offset
);
int
base_k
=
0
;
for
(
int
i
=
0
;
i
<
iters
;
i
++
)
{
int
cur_k
=
base_k
+
threadIdx
.
x
;
int
src_pos
=
perm_int_ptr
[
cur_k
];
out_half
[
cur_k
]
=
a_row_half
[
src_pos
];
base_k
+=
blockDim
.
x
;
}
if
(
rest
)
{
if
(
threadIdx
.
x
<
rest
)
{
int
cur_k
=
base_k
+
threadIdx
.
x
;
int
src_pos
=
perm_int_ptr
[
cur_k
];
out_half
[
cur_k
]
=
a_row_half
[
src_pos
];
}
}
};
for
(
int
i
=
0
;
i
<
cur_block_rows
;
i
++
)
{
int
cur_row
=
start_row
+
i
;
if
(
cur_row
<
size_m
)
{
permute_row
(
cur_row
);
}
}
}
__global__
void
compute_expert_offsets
(
int
const
*
__restrict__
topk_ids
,
int
*
__restrict__
expert_offsets
,
int
topk_length
,
int
block_size
)
{
int
expert_id
=
threadIdx
.
x
;
int
num_experts
=
blockDim
.
x
;
int
occurrences
=
0
;
for
(
int
i
=
0
;
i
<
topk_length
;
++
i
)
{
occurrences
+=
(
topk_ids
[
i
]
==
expert_id
);
}
expert_offsets
[
expert_id
+
1
]
=
occurrences
;
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
int
tot_offset
=
0
;
expert_offsets
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
tot_offset
+=
ceildiv
(
expert_offsets
[
i
+
1
],
block_size
)
*
block_size
;
expert_offsets
[
i
+
1
]
=
tot_offset
;
}
}
__syncthreads
();
}
#else
__global__
void
permute_cols_kernel
(
int4
const
*
__restrict__
a_int4_ptr
,
int
const
*
__restrict__
perm_int_ptr
,
int4
*
__restrict__
out_int4_ptr
,
int
size_m
,
int
size_k
,
int
block_rows
)
{
// Marlin is not implemented yet for SM < 8.0
assert
(
false
);
return
;
}
__global__
void
compute_expert_offsets
(
int
const
*
__restrict__
topk_ids
,
int
*
__restrict__
expert_offsets
,
int
topk_length
,
int
block_size
)
{
// Marlin is not implemented yet for SM < 8.0
assert
(
false
);
return
;
}
#endif
typedef
struct
{
int
thread_k
;
int
thread_n
;
int
num_threads
;
}
thread_config_t
;
typedef
struct
{
int
max_m_blocks
;
thread_config_t
tb_cfg
;
}
exec_config_t
;
thread_config_t
small_batch_thread_configs
[]
=
{
// Ordered by priority
// thread_k, thread_n, num_threads
{
128
,
128
,
256
},
// Default
{
128
,
64
,
128
},
// Reduce N 2X, same K
{
64
,
256
,
256
},
// Reduce K 2X, increase N 2X
{
64
,
128
,
128
},
// Reduce K 2X, same N
{
64
,
64
,
128
},
// Reduce both 2X
};
thread_config_t
large_batch_thread_configs
[]
=
{
// Ordered by priority
// thread_k, thread_n, num_threads
{
64
,
256
,
256
},
// Default
{
128
,
128
,
256
},
// Reduce N 2X, increase K 2X
{
64
,
128
,
128
},
// Reduce N 2X, same K
{
128
,
64
,
128
},
// Reduce N 4X, increase K 2X
{
64
,
64
,
128
},
// Reduce N 4X, same K
};
int
get_scales_cache_size
(
thread_config_t
const
&
th_config
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
)
{
bool
cache_scales_chunk
=
has_act_order
&&
!
is_k_full
;
int
tb_n
=
th_config
.
thread_n
;
int
tb_k
=
th_config
.
thread_k
;
// Get max scale groups per thread-block
int
tb_groups
;
if
(
group_size
==
-
1
)
{
tb_groups
=
1
;
}
else
if
(
group_size
==
0
)
{
tb_groups
=
ceildiv
(
tb_k
,
32
);
// Worst case is 32 group size
}
else
{
tb_groups
=
ceildiv
(
tb_k
,
group_size
);
}
if
(
cache_scales_chunk
)
{
int
load_groups
=
tb_groups
*
STAGES
*
2
;
// Chunk size is 2x pipeline over dim K
load_groups
=
max
(
load_groups
,
32
);
// We load at least 32 scale groups
return
load_groups
*
tb_n
*
4
;
}
else
{
int
tb_scales
=
tb_groups
*
tb_n
*
2
;
return
tb_scales
*
STAGES
;
}
}
bool
is_valid_cache_size
(
thread_config_t
const
&
th_config
,
int
max_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
scales_cache_size
,
int
max_shared_mem
)
{
int
pack_factor
=
32
/
num_bits
;
// Get B size
int
tb_k
=
th_config
.
thread_k
;
int
tb_n
=
th_config
.
thread_n
;
int
b_size
=
(
tb_k
*
tb_n
/
pack_factor
)
*
4
;
// Get A size
int
m_blocks
=
ceildiv
(
prob_m
,
16
);
int
tb_max_m
=
16
;
while
(
true
)
{
if
(
m_blocks
>=
max_m_blocks
)
{
tb_max_m
*=
max_m_blocks
;
break
;
}
max_m_blocks
--
;
if
(
max_m_blocks
==
0
)
{
TORCH_CHECK
(
false
,
"Unexpected m_blocks = "
,
m_blocks
);
}
}
int
a_size
=
(
tb_max_m
*
tb_k
)
*
2
;
float
pipe_size
=
(
a_size
+
b_size
)
*
STAGES
;
TORCH_CHECK
(
max_shared_mem
/
2
>
scales_cache_size
);
// Sanity
return
pipe_size
<
0.95
f
*
(
max_shared_mem
-
scales_cache_size
);
}
bool
is_valid_config
(
thread_config_t
const
&
th_config
,
int
max_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
max_shared_mem
)
{
// Sanity
if
(
th_config
.
thread_k
==
-
1
||
th_config
.
thread_n
==
-
1
||
th_config
.
num_threads
==
-
1
)
{
return
false
;
}
// Verify K/N are divisible by thread K/N
if
(
prob_k
%
th_config
.
thread_k
!=
0
||
prob_n
%
th_config
.
thread_n
!=
0
)
{
return
false
;
}
// thread_k can be only 128 or 64 (because it must be less than groupsize
// which is 128)
if
(
th_config
.
thread_k
!=
128
&&
th_config
.
thread_k
!=
64
)
{
return
false
;
}
// Verify min for thread K/N
if
(
th_config
.
thread_n
<
min_thread_n
||
th_config
.
thread_k
<
min_thread_k
)
{
return
false
;
}
// num_threads must be at least 128 (= 4 warps)
if
(
th_config
.
num_threads
<
128
)
{
return
false
;
}
// Determine cache for scales
int
scales_cache_size
=
get_scales_cache_size
(
th_config
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
);
// Check that pipeline fits into cache
if
(
!
is_valid_cache_size
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
scales_cache_size
,
max_shared_mem
))
{
return
false
;
}
return
true
;
}
exec_config_t
determine_thread_config
(
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
max_shared_mem
)
{
int
max_m_blocks
=
4
;
while
(
max_m_blocks
>
0
)
{
if
(
prob_m
<=
16
)
{
for
(
auto
th_config
:
small_batch_thread_configs
)
{
if
(
is_valid_config
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
max_shared_mem
))
{
return
exec_config_t
{
max_m_blocks
,
th_config
};
}
}
}
else
{
for
(
auto
th_config
:
large_batch_thread_configs
)
{
if
(
is_valid_config
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
max_shared_mem
))
{
return
exec_config_t
{
max_m_blocks
,
th_config
};
}
}
}
max_m_blocks
--
;
// Process less M blocks per invocation to reduce cache
// usage
}
return
exec_config_t
{
0
,
{
-
1
,
-
1
,
-
1
}};
}
#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \
else if (KERNEL_FUNCTION( \
q_type, thread_n_blocks, thread_k_blocks, has_act_order, \
group_blocks, num_threads, blocks, max_shared_mem, stream, \
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
replicate_input, apply_weights, m_block, max_par, \
exec_cfg.max_m_blocks)) { \
}
void
marlin_mm_moe
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
const
void
*
sorted_ids
,
const
void
*
topk_weights
,
const
void
*
topk_ids
,
const
void
*
s
,
void
*
zp
,
const
void
*
g_idx
,
const
void
*
perm
,
void
*
a_tmp
,
void
*
expert_offsets
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
vllm
::
ScalarType
const
&
q_type
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_zp
,
int
num_groups
,
int
group_size
,
int
num_experts
,
int
topk
,
int
moe_block_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
max_par
,
bool
replicate_input
,
bool
apply_weights
)
{
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
", "
,
prob_n
,
", "
,
prob_k
,
"]"
);
if
(
sms
==
-
1
)
{
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
dev
);
}
int
max_shared_mem
=
0
;
cudaDeviceGetAttribute
(
&
max_shared_mem
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
dev
);
TORCH_CHECK
(
max_shared_mem
>
0
);
int
num_bits
=
q_type
.
size_bits
();
// Set thread config
exec_config_t
exec_cfg
;
if
(
thread_k
!=
-
1
&&
thread_n
!=
-
1
)
{
// User-defined config
exec_cfg
=
exec_config_t
{
4
,
thread_config_t
{
thread_k
,
thread_n
,
USER_THREADS
}};
}
else
{
// Auto config
exec_cfg
=
determine_thread_config
(
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
max_shared_mem
);
}
TORCH_CHECK
(
exec_cfg
.
max_m_blocks
>
0
&&
is_valid_config
(
exec_cfg
.
tb_cfg
,
exec_cfg
.
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
max_shared_mem
),
"Invalid thread config: max_m_blocks = "
,
exec_cfg
.
max_m_blocks
,
", thread_k = "
,
exec_cfg
.
tb_cfg
.
thread_k
,
", thread_n = "
,
exec_cfg
.
tb_cfg
.
thread_n
,
", num_threads = "
,
exec_cfg
.
tb_cfg
.
num_threads
,
" for MKN = ["
,
prob_m
,
", "
,
prob_k
,
", "
,
prob_n
,
"] and num_bits = "
,
num_bits
,
", group_size = "
,
group_size
,
", has_act_order = "
,
has_act_order
,
", is_k_full = "
,
is_k_full
,
", max_shared_mem = "
,
max_shared_mem
);
int
num_threads
=
exec_cfg
.
tb_cfg
.
num_threads
;
thread_k
=
exec_cfg
.
tb_cfg
.
thread_k
;
thread_n
=
exec_cfg
.
tb_cfg
.
thread_n
;
int
thread_k_blocks
=
thread_k
/
16
;
int
thread_n_blocks
=
thread_n
/
16
;
int
blocks
=
sms
;
TORCH_CHECK
(
prob_n
%
thread_n
==
0
,
"prob_n = "
,
prob_n
,
" is not divisible by thread_n = "
,
thread_n
);
TORCH_CHECK
(
prob_k
%
thread_k
==
0
,
"prob_k = "
,
prob_k
,
" is not divisible by thread_k = "
,
thread_k
);
int
group_blocks
=
0
;
if
(
has_act_order
)
{
if
(
is_k_full
)
{
TORCH_CHECK
(
group_size
!=
-
1
);
group_blocks
=
group_size
/
16
;
TORCH_CHECK
(
prob_k
%
group_blocks
==
0
,
"prob_k = "
,
prob_k
,
" is not divisible by group_blocks = "
,
group_blocks
);
}
else
{
TORCH_CHECK
(
group_size
==
0
);
group_blocks
=
0
;
}
}
else
{
if
(
group_size
==
-
1
)
{
group_blocks
=
-
1
;
}
else
{
group_blocks
=
group_size
/
16
;
TORCH_CHECK
(
prob_k
%
group_blocks
==
0
,
"prob_k = "
,
prob_k
,
" is not divisible by group_blocks = "
,
group_blocks
);
}
}
int
tot_m
=
prob_m
;
const
int
*
topk_ids_ptr
=
(
const
int
*
)
topk_ids
;
int
*
expert_offsets_ptr
=
(
int
*
)
expert_offsets
;
compute_expert_offsets
<<<
1
,
num_experts
,
0
,
stream
>>>
(
topk_ids_ptr
,
expert_offsets_ptr
,
tot_m
*
topk
,
moe_block_size
);
bool
do_permute_a
=
has_act_order
;
// If we have a full K, then we can run the non-act-order version of Marlin
// (since the weight rows are reordered by increasing group ids, and by
// having a full K, we have full original groups)
if
(
is_k_full
)
{
has_act_order
=
false
;
}
int
pack_factor
=
32
/
q_type
.
size_bits
();
for
(
int
expert_idx
=
0
;
expert_idx
<
num_experts
;
++
expert_idx
)
{
const
int4
*
A_ptr
=
(
const
int4
*
)
A
;
int4
*
a_tmp_ptr
=
(
int4
*
)
a_tmp
;
const
int4
*
B_ptr
=
(
const
int4
*
)
B
+
(
prob_n
*
prob_k
/
(
pack_factor
*
4
))
*
expert_idx
;
int4
*
C_ptr
=
(
int4
*
)
C
;
const
float
*
topk_weights_ptr
=
(
const
float
*
)
topk_weights
;
const
int
*
sorted_ids_ptr
=
(
const
int
*
)
sorted_ids
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
+
num_groups
*
prob_n
/
8
*
expert_idx
;
const
int4
*
zp_ptr
=
(
const
int4
*
)
zp
+
num_groups
*
prob_n
/
(
pack_factor
*
4
)
*
expert_idx
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
+
prob_k
*
expert_idx
;
const
int
*
perm_ptr
=
(
const
int
*
)
perm
+
prob_k
*
expert_idx
;
int
*
locks
=
(
int
*
)
workspace
;
if
(
do_permute_a
)
{
// Permute A columns
int
topk_rows
=
replicate_input
?
tot_m
:
tot_m
*
topk
;
int
block_rows
=
ceildiv
(
topk_rows
,
blocks
);
permute_cols_kernel
<<<
blocks
,
num_threads
,
0
,
stream
>>>
(
A_ptr
,
perm_ptr
,
a_tmp_ptr
,
topk_rows
,
prob_k
,
block_rows
);
A_ptr
=
a_tmp_ptr
;
}
int
tot_m_blocks
=
ceildiv
(
tot_m
,
16
);
for
(
int
m_block
=
0
;
m_block
<
tot_m_blocks
;
m_block
+=
4
*
exec_cfg
.
max_m_blocks
)
{
if
(
false
)
{
}
CALL_MOE_KERNEL_FUNCTION
(
call_marlin_moe_kernel_ku4b8
)
CALL_MOE_KERNEL_FUNCTION
(
call_marlin_moe_kernel_ku8b128
)
CALL_MOE_KERNEL_FUNCTION
(
call_marlin_moe_kernel_ku4
)
else
{
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
+
str
(
prob_m
)
+
", "
+
str
(
prob_n
)
+
", "
+
str
(
prob_k
)
+
"]"
+
", has_act_order = "
+
str
(
has_act_order
)
+
", num_groups = "
+
str
(
num_groups
)
+
", group_size = "
+
str
(
group_size
)
+
", thread_n_blocks = "
+
str
(
thread_n_blocks
)
+
", thread_k_blocks = "
+
str
(
thread_k_blocks
));
}
}
}
}
}
// namespace marlin_moe
torch
::
Tensor
marlin_gemm_moe
(
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b_q_weights
,
const
torch
::
Tensor
&
sorted_ids
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
const
torch
::
Tensor
&
g_idx
,
const
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeId
const
b_q_type_id
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
int64_t
num_experts
,
int64_t
topk
,
int64_t
moe_block_size
,
bool
replicate_input
,
bool
apply_weights
)
{
vllm
::
ScalarType
const
b_q_type
=
vllm
::
ScalarType
::
from_id
(
b_q_type_id
);
bool
has_zp
=
b_zeros
.
size
(
1
)
!=
0
;
if
(
has_zp
)
{
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4
,
"b_q_type must be u4 when has_zp = True. Got = "
,
b_q_type
.
str
());
}
else
{
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4B8
||
b_q_type
==
vllm
::
kU8B128
,
"b_q_type must be uint4b8 or uint8b128. Got = "
,
b_q_type
.
str
());
}
int
pack_factor
=
32
/
b_q_type
.
size_bits
();
int
max_par
=
4
;
int
dev
=
a
.
get_device
();
auto
options_dtype
=
torch
::
TensorOptions
().
dtype
(
a
.
dtype
()).
device
(
a
.
device
());
auto
options_int
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt
).
device
(
a
.
device
());
torch
::
Tensor
c
=
torch
::
zeros
({
size_m
,
topk
,
size_n
},
options_dtype
);
torch
::
Tensor
a_tmp
=
replicate_input
?
torch
::
zeros
({
size_m
,
size_k
},
options_dtype
)
:
torch
::
zeros
({
size_m
,
topk
,
size_k
},
options_dtype
);
torch
::
Tensor
expert_offsets
=
torch
::
empty
({
num_experts
+
1
},
options_int
);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int
thread_k
=
-
1
;
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int
thread_n
=
-
1
;
// sms: number of SMs to use for the kernel (can usually be left as auto -1)
int
sms
=
-
1
;
// Detect groupsize and act_order
int
num_groups
=
-
1
;
int
group_size
=
-
1
;
bool
has_act_order
=
g_idx
.
size
(
1
)
!=
0
;
int
b_rank
=
b_scales
.
sizes
().
size
();
TORCH_CHECK
(
b_rank
==
3
,
"b_scales rank = "
,
b_rank
,
" is not 3"
);
TORCH_CHECK
(
b_scales
.
size
(
2
)
==
size_n
,
"b_scales dim 2 = "
,
b_scales
.
size
(
2
),
" is not size_n = "
,
size_n
);
num_groups
=
b_scales
.
size
(
1
);
TORCH_CHECK
(
VLLM_IMPLIES
(
!
is_k_full
,
has_act_order
),
"if is_k_full is false, has_act_order must be true"
);
if
(
has_act_order
)
{
if
(
is_k_full
)
{
TORCH_CHECK
(
num_groups
>
1
,
"For act_order, num_groups must be > 1"
);
TORCH_CHECK
(
size_k
%
num_groups
==
0
,
"size_k = "
,
size_k
,
", is not divisible by num_groups = "
,
num_groups
);
group_size
=
size_k
/
num_groups
;
}
else
{
group_size
=
0
;
}
}
else
{
if
(
num_groups
>
1
)
{
TORCH_CHECK
(
size_k
%
num_groups
==
0
,
"size_k = "
,
size_k
,
", is not divisible by b_scales.size(0) = "
,
b_scales
.
size
(
0
));
group_size
=
size_k
/
num_groups
;
}
else
{
group_size
=
-
1
;
}
}
// Verify b_zeros
if
(
has_zp
)
{
int
rank
=
b_zeros
.
sizes
().
size
();
TORCH_CHECK
(
rank
==
3
,
"b_zeros rank = "
,
rank
,
" is not 3"
);
TORCH_CHECK
(
b_zeros
.
size
(
1
)
==
num_groups
,
"b_zeros dim 1 = "
,
b_zeros
.
size
(
1
),
" is not num_groups = "
,
num_groups
);
TORCH_CHECK
(
b_zeros
.
size
(
2
)
==
size_n
/
pack_factor
,
"b_zeros dim 2 = "
,
b_zeros
.
size
(
2
),
" is not size_n / pack_factor = "
,
size_n
/
pack_factor
);
}
marlin_moe
::
marlin_mm_moe
(
a
.
data_ptr
(),
b_q_weights
.
data_ptr
(),
c
.
data_ptr
(),
sorted_ids
.
data_ptr
(),
topk_weights
.
data_ptr
(),
topk_ids
.
data_ptr
(),
b_scales
.
data_ptr
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
(),
expert_offsets
.
data_ptr
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
num_experts
,
topk
,
moe_block_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
max_par
,
replicate_input
,
apply_weights
);
return
c
;
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"marlin_gemm_moe"
,
&
marlin_gemm_moe
);
}
csrc/moe/marlin_moe_wna16/.gitignore
0 → 100644
View file @
7a985548
kernel_*.cu
\ No newline at end of file
csrc/moe/marlin_moe_wna16/generate_kernels.py
View file @
7a985548
...
...
@@ -25,15 +25,16 @@ TEMPLATE = ("template __global__ void Marlin<"
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{'true' if has_act_order else 'false'}}, "
"{{'true' if has_zp else 'false'}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );"
)
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES
=
[
"vllm::kU4"
,
"vllm::kU4B8"
,
"vllm::kU8B128"
]
SCALAR_TYPES
=
[
"vllm::kU4"
,
"vllm::kU4B8"
,
"vllm::kU8B128"
,
"vllm::kFE4M3fn"
,
"vllm::kFE2M1f"
]
THREAD_CONFIGS
=
[(
128
,
128
,
256
),
(
64
,
256
,
256
),
(
64
,
128
,
128
)]
THREAD_M_BLOCKS
=
[
0.5
,
1
,
2
,
3
,
4
]
...
...
@@ -41,7 +42,7 @@ THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS
=
[
0
,
-
1
,
2
,
4
,
8
]
GROUP_BLOCKS
=
[
0
,
-
1
,
1
,
2
,
4
,
8
]
DTYPES
=
[
"fp16"
,
"bf16"
]
...
...
@@ -52,21 +53,35 @@ def remove_old_kernels():
def
generate_new_kernels
():
for
scalar_type
,
dtype
in
itertools
.
product
(
SCALAR_TYPES
,
DTYPES
):
has_zp
=
"B"
not
in
scalar_type
all_template_str_list
=
[]
for
group_blocks
,
m_blocks
,
thread_configs
in
itertools
.
product
(
GROUP_BLOCKS
,
THREAD_M_BLOCKS
,
THREAD_CONFIGS
):
has_act_order
=
group_blocks
==
0
if
has_zp
and
has_act_order
:
# act order case only support gptq-int4 and gptq-int8
if
group_blocks
==
0
and
scalar_type
not
in
[
"vllm::kU4B8"
,
"vllm::kU8B128"
]:
continue
if
thread_configs
[
2
]
==
256
:
# for small batch (m_blocks == 1), we only need (128, 128, 256)
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if
m_blocks
<=
1
and
thread_configs
[
0
]
!=
128
:
continue
if
m_blocks
>
1
and
thread_configs
[
0
]
!=
64
:
continue
# we only support channelwise quantization and group_size == 128
# for fp8
if
scalar_type
==
"vllm::kFE4M3fn"
and
group_blocks
not
in
[
-
1
,
8
]:
continue
# nvfp4 only supports group_size == 16
if
scalar_type
==
"vllm::kFE2M1f"
and
group_blocks
not
in
[
1
,
2
]:
continue
# other quantization methods don't support group_size = 16
if
scalar_type
!=
"vllm::kFE2M1f"
and
group_blocks
==
1
:
continue
k_blocks
=
thread_configs
[
0
]
//
16
n_blocks
=
thread_configs
[
1
]
//
16
threads
=
thread_configs
[
2
]
...
...
@@ -82,8 +97,6 @@ def generate_new_kernels():
thread_k_blocks
=
k_blocks
,
m_block_size_8
=
m_blocks
==
0.5
,
stages
=
"pipe_stages"
,
has_act_order
=
has_act_order
,
has_zp
=
has_zp
,
group_blocks
=
group_blocks
,
is_zp_float
=
False
,
)
...
...
csrc/moe/marlin_moe_wna16/kernel.h
View file @
7a985548
...
...
@@ -7,18 +7,19 @@
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem
namespace
MARLIN_NAMESPACE_NAME
{
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
...
...
@@ -33,11 +34,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_zp
,
// whether zero-points are enabled
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
>
__global__
void
Marlin
(
MARLIN_KERNEL_PARAMS
);
...
...
csrc/moe/marlin_moe_wna16/marlin_template.h
View file @
7a985548
...
...
@@ -25,6 +25,7 @@
#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "quantization/gptq_marlin/dequant.h"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
...
...
@@ -48,11 +49,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_zp
,
// whether zero-points are enabled
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
>
__global__
void
Marlin
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
...
...
@@ -77,8 +76,8 @@ __global__ void Marlin(
int
prob_k
,
// reduction dimension k
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_atomic_add
,
// whether to use atomic add to reduce
bool
use_fp32_reduce
// whether to use fp32 global reduce
)
{}
bool
use_fp32_reduce
,
// whether to use fp32 global reduce
int
max_shared_mem
)
{}
}
// namespace MARLIN_NAMESPACE_NAME
...
...
@@ -166,144 +165,6 @@ __device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a,
}
}
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template
<
int
lut
>
__device__
inline
int
lop3
(
int
a
,
int
b
,
int
c
)
{
int
res
;
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
res
)
:
"r"
(
a
),
"r"
(
b
),
"r"
(
c
),
"n"
(
lut
));
return
res
;
}
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template
<
int
start_byte
,
int
mask
>
__device__
inline
uint32_t
prmt
(
uint32_t
a
)
{
uint32_t
res
;
asm
volatile
(
"prmt.b32 %0, %1, %2, %3;
\n
"
:
"=r"
(
res
)
:
"r"
(
a
),
"n"
(
start_byte
),
"n"
(
mask
));
return
res
;
}
template
<
typename
scalar_t
,
int
bit
>
__device__
inline
typename
ScalarType
<
scalar_t
>::
FragB
dequant
(
int
q
,
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
);
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template
<
>
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant
<
half
,
4
>
(
int
q
,
typename
ScalarType
<
half
>::
FragB
&
frag_b
)
{
const
int
LO
=
0x000f000f
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const
int
SUB
=
0x64086408
;
const
int
MUL
=
0x2c002c00
;
const
int
ADD
=
0xd480d480
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
SUB
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
MUL
),
*
reinterpret_cast
<
const
half2
*>
(
&
ADD
));
return
frag_b
;
}
template
<
>
__device__
inline
typename
ScalarType
<
nv_bfloat16
>::
FragB
dequant
<
nv_bfloat16
,
4
>
(
int
q
,
typename
ScalarType
<
nv_bfloat16
>::
FragB
&
frag_b
)
{
static
constexpr
uint32_t
MASK
=
0x000f000f
;
static
constexpr
uint32_t
EX
=
0x43004300
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
static
constexpr
uint32_t
MUL
=
0x3F803F80
;
static
constexpr
uint32_t
ADD
=
0xC308C308
;
frag_b
[
0
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
lo
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
hi
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
return
frag_b
;
}
//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template
<
>
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant
<
half
,
8
>
(
int
q
,
typename
ScalarType
<
half
>::
FragB
&
frag_b
)
{
static
constexpr
uint32_t
mask_for_elt_01
=
0x5250
;
static
constexpr
uint32_t
mask_for_elt_23
=
0x5351
;
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
uint32_t
lo
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_01
>
(
q
);
uint32_t
hi
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_23
>
(
q
);
static
constexpr
uint32_t
I8s_TO_F16s_MAGIC_NUM
=
0x64806480
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
frag_b
[
1
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
return
frag_b
;
}
template
<
>
__device__
inline
typename
ScalarType
<
nv_bfloat16
>::
FragB
dequant
<
nv_bfloat16
,
8
>
(
int
q
,
typename
ScalarType
<
nv_bfloat16
>::
FragB
&
frag_b
)
{
float
fp32_intermediates
[
4
];
uint32_t
*
fp32_intermediates_casted
=
reinterpret_cast
<
uint32_t
*>
(
fp32_intermediates
);
static
constexpr
uint32_t
fp32_base
=
0x4B000000
;
fp32_intermediates_casted
[
0
]
=
__byte_perm
(
q
,
fp32_base
,
0x7650
);
fp32_intermediates_casted
[
1
]
=
__byte_perm
(
q
,
fp32_base
,
0x7652
);
fp32_intermediates_casted
[
2
]
=
__byte_perm
(
q
,
fp32_base
,
0x7651
);
fp32_intermediates_casted
[
3
]
=
__byte_perm
(
q
,
fp32_base
,
0x7653
);
fp32_intermediates
[
0
]
-=
8388736.
f
;
fp32_intermediates
[
1
]
-=
8388736.
f
;
fp32_intermediates
[
2
]
-=
8388736.
f
;
fp32_intermediates
[
3
]
-=
8388736.
f
;
uint32_t
*
bf16_result_ptr
=
reinterpret_cast
<
uint32_t
*>
(
&
frag_b
);
bf16_result_ptr
[
0
]
=
__byte_perm
(
fp32_intermediates_casted
[
0
],
fp32_intermediates_casted
[
1
],
0x7632
);
bf16_result_ptr
[
1
]
=
__byte_perm
(
fp32_intermediates_casted
[
2
],
fp32_intermediates_casted
[
3
],
0x7632
);
return
frag_b
;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
template
<
typename
scalar_t
>
...
...
@@ -429,11 +290,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_zp
,
// whether zero-points are enabled
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
>
__global__
void
Marlin
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
...
...
@@ -442,9 +301,11 @@ __global__ void Marlin(
int4
*
__restrict__
C_tmp
,
// fp32 tmp output buffer (for reduce)
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
const
uint16_t
*
__restrict__
scale2_ptr
,
// fp16 global scale (for nvfp4
// only)
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
const
int32_t
*
__restrict__
sorted_token_ids_ptr
,
// moe sorted_ids
const
int32_t
*
__restrict__
expert_ids_ptr
,
// moe expert ids
const
int32_t
*
__restrict__
num_tokens_past_padded_ptr
,
// moe num tokens
...
...
@@ -458,8 +319,8 @@ __global__ void Marlin(
int
prob_k
,
// reduction dimension k
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_atomic_add
,
// whether to use atomic add to reduce
bool
use_fp32_reduce
// whether to use fp32 global reduce
)
{
bool
use_fp32_reduce
,
// whether to use fp32 global reduce
int
max_shared_mem
)
{
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
...
...
@@ -481,13 +342,26 @@ __global__ void Marlin(
extern
__shared__
int4
sh
[];
static
constexpr
auto
w_type
=
vllm
::
ScalarType
::
from_id
(
w_type_id
);
constexpr
bool
has_zp
=
w_type
==
vllm
::
kU4
||
w_type
==
vllm
::
kU8
;
constexpr
bool
is_int_type
=
w_type
==
vllm
::
kU4
||
w_type
==
vllm
::
kU8
||
w_type
==
vllm
::
kU4B8
||
w_type
==
vllm
::
kU8B128
;
// see comments of dequant.h for more details
constexpr
bool
dequant_skip_flop
=
!
is_int_type
||
has_zp
&&
!
is_zp_float
&&
!
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
||
has_zp
&&
!
is_zp_float
&&
!
(
w_type
==
vllm
::
kU8
);
scalar_t2
global_scale
;
constexpr
bool
has_act_order
=
group_blocks
==
0
;
constexpr
int
pack_factor
=
32
/
w_type
.
size_bits
();
static_assert
(
thread_m_blocks
==
1
||
!
m_block_size_8
);
constexpr
int
moe_block_size
=
m_block_size_8
?
8
:
(
16
*
thread_m_blocks
);
const
int
group_size
=
(
!
has_act_order
&&
group_blocks
==
-
1
)
?
prob_k
:
prob_k
/
num_groups
;
const
int
scales_expert_stride
=
prob_n
*
prob_k
/
group_size
/
8
;
const
int
scales_expert_stride
=
prob_n
*
prob_k
/
group_size
/
(
w_type
==
vllm
::
kFE2M1f
?
16
:
8
);
const
int
zp_expert_stride
=
is_zp_float
?
prob_n
*
prob_k
/
group_size
/
8
:
prob_n
*
prob_k
/
group_size
/
(
pack_factor
*
4
);
...
...
@@ -534,13 +408,20 @@ __global__ void Marlin(
int64_t
B_expert_off
=
0
;
int4
*
sh_block_sorted_ids_int4
=
sh
;
int4
*
sh_rd_block_sorted_ids_int4
=
sh_block_sorted_ids_int4
+
moe_block_size
/
4
;
int4
*
sh_block_topk_weights_int4
=
sh_rd_block_sorted_ids_int4
+
moe_block_size
/
4
;
// sh_block_topk_weights_int4 only need (moe_block_size / 4);
// but we pad to align to 256 bytes
int4
*
sh_new
=
sh_block_topk_weights_int4
+
moe_block_size
/
2
+
moe_block_size
;
int32_t
*
sh_block_sorted_ids
=
reinterpret_cast
<
int
*>
(
sh_block_sorted_ids_int4
);
int
4
*
sh_block_
topk_weights_int4
=
sh_block_sorted_ids_int4
+
moe_block_size
/
4
;
int
32_t
*
sh_
rd_
block_
sorted_ids
=
reinterpret_cast
<
int
*>
(
sh_rd_block_sorted_ids_int4
)
;
scalar_t2
*
sh_block_topk_weights
=
reinterpret_cast
<
scalar_t2
*>
(
sh_block_topk_weights_int4
);
int4
*
sh_new
=
sh_block_topk_weights_int4
+
moe_block_size
/
4
;
int32_t
block_num_valid_tokens
=
0
;
int32_t
locks_off
=
0
;
...
...
@@ -584,12 +465,24 @@ __global__ void Marlin(
sh_block_sorted_ids_int4
[
tid4
]
=
reinterpret_cast
<
const
int4
*>
(
sorted_token_ids_ptr
)[
block_id
*
moe_block_size
/
4
+
tid4
];
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
sh_rd_block_sorted_ids
[
tid4
*
4
+
i
]
=
sh_block_sorted_ids
[
tid4
*
4
+
i
]
/
top_k
;
if
(
mul_topk_weights
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
sh_block_topk_weights
[
tid4
*
4
+
i
]
=
Dtype
::
num2num2
(
Dtype
::
float2num
(
topk_weights_ptr
[
sh_block_sorted_ids
[
tid4
*
4
+
i
]]));
int
idx
=
tid4
*
4
+
i
;
idx
=
idx
<
block_num_valid_tokens
?
idx
:
0
;
if
constexpr
(
w_type
==
vllm
::
kFE2M1f
)
{
sh_block_topk_weights
[
idx
]
=
__hmul2
(
global_scale
,
Dtype
::
num2num2
(
Dtype
::
float2num
(
topk_weights_ptr
[
sh_block_sorted_ids
[
idx
]])));
}
else
{
sh_block_topk_weights
[
idx
]
=
Dtype
::
num2num2
(
Dtype
::
float2num
(
topk_weights_ptr
[
sh_block_sorted_ids
[
idx
]]));
}
}
}
}
...
...
@@ -620,6 +513,11 @@ __global__ void Marlin(
expert_id
=
expert_ids_ptr
[
block_id
];
}
if
constexpr
(
w_type
==
vllm
::
kFE2M1f
)
{
uint16_t
val
=
scale2_ptr
[
expert_id
];
global_scale
=
Dtype
::
num2num2
(
*
reinterpret_cast
<
scalar_t
*>
(
&
val
));
}
B_expert_off
=
expert_id
*
prob_n
*
prob_k
/
(
pack_factor
*
4
);
scales_ptr
+=
(
expert_id
-
old_expert_id
)
*
scales_expert_stride
;
if
constexpr
(
has_zp
)
{
...
...
@@ -733,7 +631,7 @@ __global__ void Marlin(
constexpr
int
s_sh_stride
=
16
*
thread_n_blocks
/
8
;
constexpr
int
s_tb_groups
=
!
has_act_order
&&
group_blocks
!=
-
1
&&
group_blocks
<
thread_k_blocks
?
thread_k_blocks
/
group_blocks
?
thread_k_blocks
/
group_blocks
/
(
w_type
==
vllm
::
kFE2M1f
?
2
:
1
)
:
1
;
constexpr
int
s_sh_stage
=
s_tb_groups
*
s_sh_stride
;
int
s_gl_rd_delta
=
s_gl_stride
;
...
...
@@ -743,6 +641,7 @@ __global__ void Marlin(
constexpr
int
g_idx_stage
=
has_act_order
?
(
tb_k
*
sizeof
(
int
))
/
16
:
0
;
// constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups;
constexpr
int
act_s_max_num_groups
=
32
;
int
act_s_col_stride
=
1
;
int
act_s_col_warp_stride
=
act_s_col_stride
*
8
;
int
tb_n_warps
=
thread_n_blocks
/
4
;
...
...
@@ -758,9 +657,9 @@ __global__ void Marlin(
int
zp_gl_rd_delta
=
zp_gl_stride
;
// Global A read index of current thread.
int
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
)
;
a_gl_rd
+=
a_gl_rd_delta_o
*
slice_row
;
int
a_gl_rd
_row
=
threadIdx
.
x
/
a_gl_rd_delta_o
;
int
a_gl_rd_col
=
a_gl_rd_delta_o
*
slice_row
+
threadIdx
.
x
%
a_gl_rd_delta_o
;
// Shared write index of current thread.
int
a_sh_wr
=
a_sh_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
...
...
@@ -774,8 +673,8 @@ __global__ void Marlin(
(
threadIdx
.
x
%
b_sh_stride_threads
)
*
b_thread_vecs
;
b_gl_rd
+=
b_sh_stride
*
slice_col
;
b_gl_rd
+=
b_gl_rd_delta_o
*
slice_row
;
int
b_sh_wr
=
threadIdx
.
x
*
b_thread_vecs
;
int
b_sh_rd
=
threadIdx
.
x
*
b_thread_vecs
;
auto
b_sh_wr
=
threadIdx
.
x
*
b_thread_vecs
;
auto
b_sh_rd
=
threadIdx
.
x
*
b_thread_vecs
;
// For act_order
constexpr
int
k_iter_size
=
tb_k
/
b_sh_wr_iters
;
...
...
@@ -790,11 +689,12 @@ __global__ void Marlin(
if
constexpr
(
group_blocks
==
-
1
)
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
else
{
s_gl_rd
=
s_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
+
s_gl_rd
=
s_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
/
(
w_type
==
vllm
::
kFE2M1f
?
2
:
1
)
+
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
int
s_sh_wr
=
threadIdx
.
x
;
auto
s_sh_wr
=
threadIdx
.
x
;
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
// Zero-points
...
...
@@ -807,17 +707,27 @@ __global__ void Marlin(
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
int
zp_sh_wr
=
threadIdx
.
x
;
auto
zp_sh_wr
=
threadIdx
.
x
;
bool
zp_sh_wr_pred
=
threadIdx
.
x
<
zp_sh_stride
;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int
s_sh_rd
;
if
constexpr
(
group_blocks
!=
-
1
)
if
constexpr
(
group_blocks
!=
-
1
&&
w_type
==
vllm
::
kFE2M1f
)
{
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
4
;
else
if
constexpr
(
group_blocks
==
-
1
&&
(
m_block_size_8
||
has_zp
))
s_sh_rd
=
s_sh_rd
*
2
+
warp_row
%
2
;
}
else
if
constexpr
(
group_blocks
!=
-
1
)
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
4
;
else
if
constexpr
(
group_blocks
==
-
1
&&
(
m_block_size_8
||
(
has_zp
&&
!
dequant_skip_flop
)))
s_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
8
;
else
...
...
@@ -851,7 +761,7 @@ __global__ void Marlin(
// each warp must also write a consecutive memory segment?
auto
transform_a
=
[
&
](
int
i
)
{
int
row
=
i
/
a_gl_rd_delta_o
;
return
a_gl_rd_delta_o
*
row
+
(
i
%
a_gl_rd_delta_o
)
^
row
;
return
a_gl_rd_delta_o
*
row
+
(
i
%
a_gl_rd_delta_o
)
^
(
row
%
8
)
;
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
...
...
@@ -879,12 +789,28 @@ __global__ void Marlin(
B_ptr
[
i
]
=
B
+
b_gl_rd_delta_i
*
i
+
b_gl_rd
;
// Shared memory storage for global fetch pipelines.
int4
*
sh_a
=
sh_new
;
int4
*
sh_b
=
sh_a
+
(
stages
*
a_sh_stage
);
int4
*
sh_g_idx
=
sh_b
+
(
stages
*
b_sh_stage
);
constexpr
int
sh_red_size
=
(
2
*
thread_n_blocks
+
1
)
*
16
*
thread_m_blocks
;
constexpr
int
sh_b_size
=
stages
*
b_sh_stage
;
int4
*
sh_b
=
sh_new
;
int4
*
sh_red
=
sh_new
;
int4
*
sh_g_idx
=
sh_b
+
(
sh_red_size
>
sh_b_size
?
sh_red_size
:
sh_b_size
);
int4
*
sh_zp
=
sh_g_idx
+
(
stages
*
g_idx_stage
);
constexpr
int
sh_s_size
=
has_act_order
?
(
act_s_max_num_groups
*
s_sh_stride
)
:
(
stages
*
s_sh_stage
);
int4
*
sh_s
=
sh_zp
+
(
stages
*
zp_sh_stage
);
int4
*
sh_red
=
sh_b
;
// shared memory reused by reduction should be smaller than
// shared memory used by weight.
static_assert
(
thread_m_blocks
*
16
*
thread_n_blocks
*
16
/
8
<=
stages
*
b_sh_stage
);
int4
*
sh_a
=
sh_s
+
sh_s_size
;
constexpr
int
shm_size_used
=
moe_block_size
+
stages
*
(
g_idx_stage
+
zp_sh_stage
)
+
sh_s_size
+
(
sh_red_size
>
sh_b_size
?
sh_red_size
:
sh_b_size
);
// all remaining shared memory is used to cache A (input)
// sh_a_max_row is at least ` stages * 16 * thread_m_blocks `
int
sh_a_max_row
=
((
max_shared_mem
-
1024
)
/
16
-
shm_size_used
)
/
(
thread_k_blocks
*
2
);
// Register storage for double buffer of shared memory reads.
FragA
frag_a
[
2
][
thread_m_blocks
];
...
...
@@ -905,15 +831,14 @@ __global__ void Marlin(
int
sh_first_group_id
=
-
1
;
int
sh_num_groups
=
-
1
;
constexpr
int
sh_max_num_groups
=
32
;
auto
fetch_act_order_scales_to_shared
=
[
&
](
bool
is_async
,
int
first_group_id
,
int
last_group_id
)
{
sh_first_group_id
=
first_group_id
;
sh_num_groups
=
last_group_id
-
first_group_id
+
1
;
if
(
sh_num_groups
<
sh
_max_num_groups
)
{
sh_num_groups
=
s
h
_max_num_groups
;
if
(
sh_num_groups
>
act_s
_max_num_groups
)
{
sh_num_groups
=
act_
s_max_num_groups
;
}
if
(
sh_first_group_id
+
sh_num_groups
>
num_groups
)
{
...
...
@@ -940,27 +865,31 @@ __global__ void Marlin(
}
}
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
int
a_remaining_load_count_in_slice
=
stages
;
auto
fetch_to_shared
=
[
&
](
int
pipe
,
int
a_off
,
bool
pred
=
true
)
{
bool
should_load_a
=
true
;
int
max_num_stage_groups
=
((
sh_a_max_row
-
moe_block_size
)
/
moe_block_size
+
1
)
/
stages
;
max_num_stage_groups
=
max
(
max_num_stage_groups
,
1
);
auto
fetch_to_shared
=
[
&
](
int
pipe
,
int
a_off
,
bool
pred
=
true
,
int
pipe_a
=
0
)
{
if
(
pred
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
if
(
prob_k
>
thread_k_blocks
*
16
*
stages
||
slice_col
==
0
||
a_remaining_load_count_in_slice
>
0
)
{
a_remaining_load_count_in_slice
--
;
if
(
should_load_a
)
{
int4
*
sh_a_stage
=
sh_a
+
moe_block_size
*
a_sh_stride
*
pipe_a
;
#pragma unroll
for
(
int
i
=
0
;
i
<
a_sh_wr_iters
;
i
++
)
{
int
a_idx
=
a_gl_rd_delta_i
*
i
+
a_gl_rd
+
a_gl_rd_delta_o
*
a_off
;
int
row
=
a_idx
/
a_gl_stride
;
int
row
=
a_gl_rd_delta_i
/
a_gl_stride
*
i
+
a_gl_rd_row
;
int64_t
sorted_row
=
0
;
if
(
!
m_block_size_8
||
row
<
8
)
sorted_row
=
sh_block_sorted_ids
[
row
]
/
top_k
;
int64_t
true_idx
=
sorted_row
*
a_gl_stride
+
a_idx
%
a_gl_stride
;
sorted_row
=
sh_rd_block_sorted_ids
[
row
];
int64_t
true_idx
=
sorted_row
*
a_gl_stride
+
a_gl_rd_col
+
a_gl_rd_delta_o
*
a_off
;
cp_async4_pred
(
&
sh_a_stage
[
a_sh_wr_trans
[
i
]],
&
A
[
true_idx
],
row
<
block_num_valid_tokens
);
}
}
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
...
...
@@ -1063,8 +992,8 @@ __global__ void Marlin(
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto
fetch_to_registers
=
[
&
](
int
k
,
int
pipe
)
{
int4
*
sh_a_stage
=
sh_a
+
a_sh_st
ag
e
*
pipe
;
auto
fetch_to_registers
=
[
&
](
int
k
,
int
pipe
,
int
pipe_a
=
0
)
{
int4
*
sh_a_stage
=
sh_a
+
moe_block_size
*
a_sh_st
rid
e
*
pipe
_a
;
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
ldsm
<
m_block_size_8
?
2
:
4
,
scalar_t
>
(
...
...
@@ -1109,12 +1038,17 @@ __global__ void Marlin(
}
}
else
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
];
if
(
k
%
b_sh_wr_iters
==
0
)
{
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
];
}
else
{
reinterpret_cast
<
int4
*>
(
&
frag_s
[
1
])[
0
]
=
reinterpret_cast
<
int4
*>
(
&
frag_s
[
0
])[
0
];
}
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
...
...
@@ -1123,12 +1057,19 @@ __global__ void Marlin(
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
int
cur_group_id
=
k_blocks
/
group_blocks
;
int
cur_group_id
=
k_blocks
/
(
group_blocks
*
(
w_type
==
vllm
::
kFE2M1f
?
2
:
1
));
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
+
cur_group_id
*
s_sh_stride
];
if
constexpr
(
w_type_id
!=
vllm
::
kFE2M1f
.
id
())
{
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
+
cur_group_id
*
s_sh_stride
];
}
else
{
reinterpret_cast
<
int2
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
reinterpret_cast
<
int2
*>
(
sh_s_stage
)[
s_sh_rd
+
cur_group_id
*
(
2
*
s_sh_stride
)];
}
}
}
...
...
@@ -1152,7 +1093,7 @@ __global__ void Marlin(
// Determine "position" inside the thread-block (based on warp and
// thread-id)
int
warp_id
=
threadIdx
.
x
/
32
;
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
// Each warp processes 4 16-size tiles over N
...
...
@@ -1161,7 +1102,7 @@ __global__ void Marlin(
cur_k
+=
warp_row
*
16
;
int
th_id
=
threadIdx
.
x
%
32
;
auto
th_id
=
threadIdx
.
x
%
32
;
cur_k
+=
(
th_id
%
4
)
*
2
;
// Due to tensor-core layout for fp16 B matrix
int
s_col_shift
=
...
...
@@ -1222,15 +1163,18 @@ __global__ void Marlin(
}
}
else
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
if
(
k
%
b_sh_wr_iters
==
0
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
#pragma unroll
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
...
...
@@ -1251,6 +1195,7 @@ __global__ void Marlin(
sh_zp_stage
+=
cur_group_id
*
zp_sh_stride
;
#pragma unroll
for
(
int
i
=
0
;
i
<
num_ints_per_thread
;
i
++
)
{
frag_qzp
[
k
%
2
][
i
]
=
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
...
...
@@ -1263,12 +1208,16 @@ __global__ void Marlin(
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_zpf
[
k
%
2
])[
0
]
=
sh_zp_stage
[
zp_sh_rd
];
if
(
k
%
b_sh_wr_iters
==
0
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_zpf
[
k
%
2
])[
0
]
=
sh_zp_stage
[
zp_sh_rd
];
}
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
...
...
@@ -1292,6 +1241,10 @@ __global__ void Marlin(
}
};
auto
dequant_data
=
[
&
](
int
q
,
scalar_t2
*
frag_b_ptr
)
{
dequant
<
scalar_t2
,
w_type_id
,
dequant_skip_flop
>
(
q
,
frag_b_ptr
);
};
// Execute the actual tensor core matmul of a sub-tile.
bool
is_first_matmul_in_slice
=
true
;
auto
matmul
=
[
&
](
int
k
)
{
...
...
@@ -1315,15 +1268,27 @@ __global__ void Marlin(
zp_quant_1
=
frag_qzp
[
k2
][
1
];
}
dequant
<
scalar_t
,
w_type
.
size_bits
()
>
(
zp_quant_0
,
frag_zp_0
);
dequant
<
scalar_t
,
w_type
.
size_bits
()
>
(
zp_quant_1
,
frag_zp_1
);
frag_zp
[
0
]
=
frag_zp_0
[
0
];
frag_zp
[
1
]
=
frag_zp_0
[
1
];
frag_zp
[
2
]
=
frag_zp_1
[
0
];
frag_zp
[
3
]
=
frag_zp_1
[
1
];
dequant_data
(
zp_quant_0
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_zp
));
dequant_data
(
zp_quant_1
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_zp
)
+
2
);
}
}
if
constexpr
(
!
dequant_skip_flop
&&
has_zp
&&
is_zp_float
)
{
if
(
is_new_zp
)
{
reinterpret_cast
<
int4
*>
(
&
frag_zp
)[
0
]
=
reinterpret_cast
<
int4
*>
(
&
frag_zpf
[
k2
])[
0
];
}
}
if
constexpr
(
w_type
==
vllm
::
kFE2M1f
)
{
int
s_quant_0
=
reinterpret_cast
<
int
*>
(
frag_s
[
k2
])[
0
];
int
s_quant_1
=
reinterpret_cast
<
int
*>
(
frag_s
[
k2
])[
1
];
dequant_fp8_scales
<
scalar_t2
>
(
s_quant_0
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_s
[
k2
]));
dequant_fp8_scales
<
scalar_t2
>
(
s_quant_1
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_s
[
k2
])
+
2
);
}
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
...
...
@@ -1332,7 +1297,10 @@ __global__ void Marlin(
FragB
frag_b1
;
int
b_quant_0
,
b_quant_1
;
if
constexpr
(
w_type
.
size_bits
()
==
4
)
{
if
constexpr
(
w_type_id
==
vllm
::
kFE2M1f
.
id
())
{
b_quant_1
=
frag_b_quant
[
k2
][
0
][
j
];
b_quant_0
=
b_quant_1
<<
8
;
}
else
if
constexpr
(
w_type
.
size_bits
()
==
4
)
{
b_quant_0
=
frag_b_quant
[
k2
][
0
][
j
];
b_quant_1
=
b_quant_0
>>
8
;
}
else
{
...
...
@@ -1342,8 +1310,13 @@ __global__ void Marlin(
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
}
dequant
<
scalar_t
,
w_type
.
size_bits
()
>
(
b_quant_0
,
frag_b0
);
dequant
<
scalar_t
,
w_type
.
size_bits
()
>
(
b_quant_1
,
frag_b1
);
dequant_data
(
b_quant_0
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_b0
));
dequant_data
(
b_quant_1
,
reinterpret_cast
<
scalar_t2
*>
(
&
frag_b1
));
if
constexpr
(
dequant_skip_flop
&&
has_zp
&&
!
is_zp_float
)
{
sub_zp
<
scalar_t
>
(
frag_b0
,
frag_zp
[
j
],
0
);
sub_zp
<
scalar_t
>
(
frag_b1
,
frag_zp
[
j
],
1
);
}
// Apply scale to frag_b0
if
constexpr
(
has_act_order
)
{
...
...
@@ -1351,9 +1324,9 @@ __global__ void Marlin(
scale4
<
scalar_t
>
(
frag_b0
,
act_frag_s
[
k2
][
0
][
j
],
act_frag_s
[
k2
][
1
][
j
],
act_frag_s
[
k2
][
2
][
j
],
act_frag_s
[
k2
][
3
][
j
],
0
);
scale4
<
scalar_t
>
(
frag_b1
,
act_frag_s
[
k2
][
0
][
j
],
act_frag_s
[
k2
][
1
][
j
],
act_frag_s
[
k
][
2
][
j
],
act_frag_s
[
k2
][
3
][
j
],
1
);
}
else
if
constexpr
(
has_zp
&&
!
is_zp_float
&&
group_blocks
==
-
1
)
{
act_frag_s
[
k
2
][
2
][
j
],
act_frag_s
[
k2
][
3
][
j
],
1
);
}
else
if
constexpr
(
!
dequant_skip_flop
&&
has_zp
&&
!
is_zp_float
&&
group_blocks
==
-
1
)
{
int
idx
=
(
threadIdx
.
x
/
4
)
%
2
;
scalar_t2
s2
=
Dtype
::
nums2num2
(
reinterpret_cast
<
scalar_t
*>
(
&
frag_s
[
j
/
2
][
j
%
2
*
2
+
0
])[
idx
],
...
...
@@ -1361,18 +1334,12 @@ __global__ void Marlin(
if
(
is_new_zp
)
frag_zp
[
j
]
=
__hmul2
(
frag_zp
[
j
],
s2
);
scale_and_sub
<
scalar_t
>
(
frag_b0
,
s2
.
x
,
frag_zp
[
j
].
x
);
scale_and_sub
<
scalar_t
>
(
frag_b1
,
s2
.
y
,
frag_zp
[
j
].
y
);
}
else
if
constexpr
(
has_z
p
&&
!
i
s_zp
_float
&&
group_blocks
!=
-
1
)
{
}
else
if
constexpr
(
!
dequant_skip_flo
p
&&
ha
s_zp
&&
group_blocks
!=
-
1
)
{
if
(
is_new_zp
)
frag_zp
[
j
]
=
__hmul2
(
frag_zp
[
j
],
*
reinterpret_cast
<
scalar_t2
*>
(
&
frag_s
[
k2
][
j
]));
scale_and_sub
<
scalar_t
>
(
frag_b0
,
frag_s
[
k
%
2
][
j
][
0
].
x
,
frag_zp
[
j
].
x
);
scale_and_sub
<
scalar_t
>
(
frag_b1
,
frag_s
[
k
%
2
][
j
][
0
].
y
,
frag_zp
[
j
].
y
);
}
else
if
constexpr
(
has_zp
&&
is_zp_float
&&
group_blocks
!=
-
1
)
{
if
(
is_new_zp
)
frag_zpf
[
k2
][
j
]
=
__hmul2
(
frag_zpf
[
k2
][
j
],
*
reinterpret_cast
<
scalar_t2
*>
(
&
frag_s
[
k2
][
j
]));
scale_and_sub
<
scalar_t
>
(
frag_b0
,
frag_s
[
k2
][
j
].
x
,
frag_zpf
[
k2
][
j
].
x
);
scale_and_sub
<
scalar_t
>
(
frag_b1
,
frag_s
[
k2
][
j
].
y
,
frag_zpf
[
k2
][
j
].
y
);
scale_and_sub
<
scalar_t
>
(
frag_b0
,
frag_s
[
k2
][
j
][
0
].
x
,
frag_zp
[
j
].
x
);
scale_and_sub
<
scalar_t
>
(
frag_b1
,
frag_s
[
k2
][
j
][
0
].
y
,
frag_zp
[
j
].
y
);
}
else
if
constexpr
(
group_blocks
!=
-
1
)
{
scale
<
scalar_t
>
(
frag_b0
,
frag_s
[
k2
][
j
],
0
);
scale
<
scalar_t
>
(
frag_b1
,
frag_s
[
k2
][
j
],
1
);
...
...
@@ -1397,7 +1364,7 @@ __global__ void Marlin(
auto
thread_block_reduce
=
[
&
]()
{
constexpr
int
red_off
=
threads
/
b_sh_stride_threads
/
2
;
if
(
red_off
>=
1
)
{
int
red_idx
=
threadIdx
.
x
/
b_sh_stride_threads
;
auto
red_idx
=
threadIdx
.
x
/
b_sh_stride_threads
;
constexpr
int
red_sh_stride
=
b_sh_stride_threads
*
4
*
2
;
constexpr
int
red_sh_delta
=
b_sh_stride_threads
;
int
red_sh_rd
=
red_sh_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
...
...
@@ -1634,10 +1601,17 @@ __global__ void Marlin(
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
w_type
.
size_bits
()
==
4
&&
!
has_zp
)
{
w_type
.
size_bits
()
==
4
&&
(
has_zp
&&
dequant_skip_flop
||
!
has_zp
))
{
res
=
__hmul2
(
res
,
s
[
0
]);
}
if
constexpr
(
w_type
==
vllm
::
kFE2M1f
)
{
if
(
!
mul_topk_weights
)
{
res
=
__hmul2
(
res
,
global_scale
);
}
}
if
constexpr
(
m_block_size_8
)
{
((
scalar_t
*
)
sh_red
)[
idx
]
=
res
.
x
;
((
scalar_t
*
)
sh_red
)[
idx
+
8
*
c_sh_stride
]
=
res
.
y
;
...
...
@@ -1728,10 +1702,12 @@ __global__ void Marlin(
if
constexpr
(
has_zp
&&
!
is_zp_float
&&
group_blocks
==
-
1
)
{
if
(
i
==
0
)
{
fetch_col_zp_to_shared
();
fetch_col_scale_to_shared
();
if
constexpr
(
!
dequant_skip_flop
)
{
fetch_col_scale_to_shared
();
}
}
}
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
);
fetch_to_shared
(
i
,
i
,
i
<
slice_iters
,
i
);
}
zero_accums
();
...
...
@@ -1740,8 +1716,10 @@ __global__ void Marlin(
fetch_to_registers
(
0
,
0
);
fetch_scales_to_registers
(
0
,
0
);
fetch_zp_to_registers
(
0
,
0
);
a_gl_rd
+=
a_gl_rd_delta_o
*
(
stages
-
1
);
slice_k_start_shared_fetch
+=
tb_k
*
(
stages
-
1
);
a_gl_rd_col
+=
a_gl_rd_delta_o
*
(
stages
-
1
);
if
constexpr
(
has_act_order
)
{
slice_k_start_shared_fetch
+=
tb_k
*
(
stages
-
1
);
}
};
if
(
slice_iters
)
{
start_pipes
();
...
...
@@ -1754,43 +1732,59 @@ __global__ void Marlin(
// have even length meaning that the next iteration will always start at
// index 0.
for
(
int
stage_group_id
=
0
;
stage_group_id
<
max_num_stage_groups
;
stage_group_id
++
)
{
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
stages
;)
{
for
(
int
pipe
=
0
;
pipe
<
stages
;)
{
#pragma unroll
for
(
int
k
=
0
;
k
<
b_sh_wr_iters
;
k
++
)
{
fetch_to_registers
(
k
+
1
,
pipe
%
stages
);
fetch_scales_to_registers
(
k
+
1
,
pipe
);
fetch_zp_to_registers
(
k
+
1
,
pipe
);
if
(
k
==
b_sh_wr_iters
-
2
)
{
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
slice_iters
>=
stages
);
pipe
++
;
wait_for_stage
();
init_same_group
(
pipe
%
stages
);
for
(
int
k
=
0
;
k
<
b_sh_wr_iters
;
k
++
)
{
int
idx
=
(
pipe
>=
stages
&&
stage_group_id
==
max_num_stage_groups
-
1
)
?
(
pipe
-
stages
)
:
(
pipe
+
stage_group_id
*
stages
);
fetch_to_registers
(
k
+
1
,
pipe
%
stages
,
idx
);
fetch_scales_to_registers
(
k
+
1
,
pipe
);
fetch_zp_to_registers
(
k
+
1
,
pipe
);
if
(
k
==
b_sh_wr_iters
-
2
)
{
int
idx
=
(
pipe
>=
1
&&
stage_group_id
==
max_num_stage_groups
-
1
)
?
(
pipe
-
1
)
:
(
pipe
+
(
stage_group_id
+
1
)
*
stages
-
1
);
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
slice_iters
>=
stages
,
idx
);
pipe
++
;
wait_for_stage
();
init_same_group
(
pipe
%
stages
);
}
matmul
(
k
);
}
slice_iters
--
;
if
(
slice_iters
==
0
)
{
break
;
}
matmul
(
k
);
}
slice_iters
--
;
if
(
slice_iters
==
0
)
{
break
;
}
}
a_remaining_load_count_in_slice
=
0
;
a_gl_rd
+=
a_gl_rd_delta_o
*
stages
;
slice_k_start
+=
tb_k
*
stages
;
slice_k_start_shared_fetch
+=
tb_k
*
stages
;
a_gl_rd_col
+=
a_gl_rd_delta_o
*
stages
;
if
constexpr
(
has_act_order
)
{
int
first_group_id
=
g_idx
[
slice_k_start
];
int
last_g_idx
=
slice_k_start
+
stages
*
tb_k
*
2
;
if
(
last_g_idx
>=
prob_k
)
{
last_g_idx
=
prob_k
-
1
;
if
constexpr
(
has_act_order
)
{
slice_k_start
+=
tb_k
*
stages
;
if
(
slice_k_start
<
prob_k
)
{
slice_k_start_shared_fetch
+=
tb_k
*
stages
;
int
first_group_id
=
g_idx
[
slice_k_start
];
int
last_g_idx
=
slice_k_start
+
stages
*
tb_k
*
2
;
if
(
last_g_idx
>=
prob_k
)
{
last_g_idx
=
prob_k
-
1
;
}
int
last_group_id
=
g_idx
[
last_g_idx
];
if
(
last_group_id
>=
sh_first_group_id
+
sh_num_groups
)
{
fetch_act_order_scales_to_shared
(
false
,
first_group_id
,
last_group_id
);
__syncthreads
();
}
}
}
int
last_group_id
=
g_idx
[
last_g_idx
];
if
(
last_group_id
>=
sh_first_group_id
+
sh_num_groups
)
{
fetch_act_order_scales_to_shared
(
false
,
first_group_id
,
last_group_id
);
__syncthreads
();
if
(
slice_iters
==
0
)
{
break
;
}
}
...
...
@@ -1802,7 +1796,8 @@ __global__ void Marlin(
bool
last
=
slice_idx
==
slice_count
-
1
;
// For per-column scales, we only fetch them here in the final step before
// write-out
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
!
has_zp
)
{
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
(
has_zp
&&
dequant_skip_flop
||
!
has_zp
))
{
if
(
w_type
.
size_bits
()
==
8
||
(
last
||
use_atomic_add
))
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
...
...
@@ -1812,7 +1807,8 @@ __global__ void Marlin(
}
thread_block_reduce
();
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
!
has_zp
)
{
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
(
has_zp
&&
dequant_skip_flop
||
!
has_zp
))
{
if
(
w_type
.
size_bits
()
==
8
||
(
last
||
use_atomic_add
))
{
cp_async_wait
<
0
>
();
__syncthreads
();
...
...
@@ -1836,7 +1832,8 @@ __global__ void Marlin(
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
w_type
.
size_bits
()
==
8
&&
!
has_zp
)
{
w_type
.
size_bits
()
==
8
&&
(
has_zp
&&
dequant_skip_flop
||
!
has_zp
))
{
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
...
...
@@ -1877,15 +1874,30 @@ __global__ void Marlin(
if
(
last
||
use_atomic_add
)
// only the last block in a slice actually writes the result
write_result
();
i
f
(
slice_row
)
a_remaining_load_count_in_slice
=
stages
;
i
nt
old_
slice_row
=
slice_row
;
slice_row
=
0
;
slice_col_par
++
;
slice_col
++
;
is_first_matmul_in_slice
=
true
;
init_slice
();
// Should we load A matrix in next slice?
// `slice_col == 0`: when move to a new moe block
// `old_slice_row > 0`:
// when the last slice is not starting from k_index == 0
// (only happen when it is the first slice of a threadblock)
// `prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups`:
// when the required shared memory size is larger than
// the remaining shared memory
if
(
slice_col
==
0
||
old_slice_row
||
prob_k
>
thread_k_blocks
*
16
*
stages
*
max_num_stage_groups
)
{
should_load_a
=
true
;
}
else
{
should_load_a
=
false
;
}
if
(
slice_iters
)
{
a_gl_rd
=
a_gl_stride
*
(
threadIdx
.
x
/
a_gl_rd_delta_o
)
+
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
a_gl_rd_col
=
(
threadIdx
.
x
%
a_gl_rd_delta_o
);
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
B_ptr
[
i
]
+=
b_sh_stride
-
b_gl_rd_delta_o
*
k_tiles
;
...
...
@@ -1900,12 +1912,10 @@ __global__ void Marlin(
slice_k_finish
=
slice_k_start
+
tb_k
*
slice_iters
;
slice_k_start_shared_fetch
=
slice_k_start
;
slice_n_offset
=
act_s_col_tb_stride
*
slice_col
;
}
else
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
zp_gl_rd
=
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
start_pipes
();
}
}
...
...
csrc/moe/marlin_moe_wna16/ops.cu
View file @
7a985548
...
...
@@ -116,7 +116,7 @@ __global__ void permute_cols_kernel(
int
base_k
=
0
;
for
(
int
i
=
0
;
i
<
iters
;
i
++
)
{
int
cur_k
=
base_k
+
threadIdx
.
x
;
auto
cur_k
=
base_k
+
threadIdx
.
x
;
int
src_pos
=
perm_int_ptr
[
cur_k
];
out_half
[
cur_k
]
=
a_row_half
[
src_pos
];
...
...
@@ -126,7 +126,7 @@ __global__ void permute_cols_kernel(
if
(
rest
)
{
if
(
threadIdx
.
x
<
rest
)
{
int
cur_k
=
base_k
+
threadIdx
.
x
;
auto
cur_k
=
base_k
+
threadIdx
.
x
;
int
src_pos
=
perm_int_ptr
[
cur_k
];
out_half
[
cur_k
]
=
a_row_half
[
src_pos
];
...
...
@@ -195,7 +195,6 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
tb_groups
*
pipe_stages
*
2
;
// Chunk size is 2x pipeline over dim K
load_groups
=
max
(
load_groups
,
32
);
// We load at least 32 scale groups
return
load_groups
*
tb_n
*
2
;
}
else
{
int
tb_scales
=
tb_groups
*
tb_n
*
2
;
...
...
@@ -203,22 +202,24 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
}
}
int
get_kernel_cache_size
(
thread_config_t
const
&
th_config
,
int
thread_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
has_zp
,
int
is_zp_float
)
{
int
get_kernel_cache_size
(
thread_config_t
const
&
th_config
,
bool
m_block_size_8
,
int
thread_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
has_zp
,
int
is_zp_float
)
{
int
pack_factor
=
32
/
num_bits
;
// Get B size
int
tb_k
=
th_config
.
thread_k
;
int
tb_n
=
th_config
.
thread_n
;
int
tb_m
=
thread_m_blocks
*
16
;
int
tb_m
=
thread_m_blocks
*
(
m_block_size_8
?
8
:
16
)
;
// shm size for block_sorted_ids/block_topk_weights
// shm size for
block_sorted_ids/rd_
block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
int
sh_block_meta_size
=
tb_m
*
4
*
2
;
int
sh_block_meta_size
=
tb_m
*
4
;
int
sh_a_size
=
pipe_stages
*
(
tb_m
*
tb_k
)
*
2
;
int
sh_b_size
=
pipe_stages
*
(
tb_k
*
tb_n
/
pack_factor
)
*
4
;
int
sh_red_size
=
tb_m
*
(
tb_n
+
8
)
*
2
;
int
sh_s_size
=
get_scales_cache_size
(
th_config
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
);
...
...
@@ -233,16 +234,17 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
sh_zp_size
=
sh_s_size
/
2
;
}
int
total_size
=
sh_
a
_size
+
sh_
b
_size
+
sh_
s
_size
+
sh_
zp
_size
+
sh_g_idx_size
+
sh_block_meta_size
;
int
total_size
=
max
(
sh_
b
_size
,
sh_
red
_size
)
+
sh_
a
_size
+
sh_
s
_size
+
sh_zp_size
+
sh_g_idx_size
+
sh_block_meta_size
;
return
total_size
;
}
bool
is_valid_config
(
thread_config_t
const
&
th_config
,
int
thread_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
has_zp
,
int
is_zp_float
,
int
max_shared_mem
)
{
bool
is_valid_config
(
thread_config_t
const
&
th_config
,
bool
m_block_size_8
,
int
thread_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
has_zp
,
int
is_zp_float
,
int
max_shared_mem
)
{
// Sanity
if
(
th_config
.
thread_k
==
-
1
||
th_config
.
thread_n
==
-
1
||
th_config
.
num_threads
==
-
1
)
{
...
...
@@ -266,143 +268,129 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
// Check that pipeline fits into cache
int
cache_size
=
get_kernel_cache_size
(
th_config
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
);
th_config
,
m_block_size_8
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
);
return
cache_size
<=
max_shared_mem
;
}
#define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
IS_ZP_FLOAT>; \
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
}
#define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false)
#define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false)
#define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false)
#define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false)
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
// this is the most common cases
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order case (group_blocks == 0)
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF(W_TYPE) \
COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
COMMON_GET_IF_M234(W_TYPE, 8, 4, 128)
#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define FP4_GET_IF(W_TYPE) \
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
FP4_GET_IF_M234(W_TYPE, 8, 4, 128)
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF(W_TYPE) \
FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
FZP_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
true) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true)
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF(W_TYPE) \
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
ACT_GET_IF_M234(W_TYPE, 8, 4, 128)
template
<
typename
scalar_t
>
MarlinFuncPtr
get_marlin_kernel
(
const
vllm
::
ScalarType
q_type
,
...
...
@@ -415,23 +403,17 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
auto
kernel
=
MarlinDefault
;
if
(
false
)
{
}
GPTQ_GET_IF_M1
(
vllm
::
kU4B8
,
8
,
8
,
256
)
GPTQ_GET_IF_M1
(
vllm
::
kU4B8
,
8
,
4
,
128
)
GPTQ_GET_IF_M234
(
vllm
::
kU4B8
,
16
,
4
,
256
)
GPTQ_GET_IF_M234
(
vllm
::
kU4B8
,
8
,
4
,
128
)
GPTQ_GET_IF_M1
(
vllm
::
kU8B128
,
8
,
8
,
256
)
GPTQ_GET_IF_M1
(
vllm
::
kU8B128
,
8
,
4
,
128
)
COMMON_GET_IF
(
vllm
::
kU4
)
COMMON_GET_IF
(
vllm
::
kU4B8
)
COMMON_GET_IF
(
vllm
::
kU8B128
)
GPTQ_GET_IF_M234
(
vllm
::
kU8B128
,
16
,
4
,
256
)
GPTQ_GET_IF_M234
(
vllm
::
kU8B128
,
8
,
4
,
128
)
BIGGROUP_GET_IF
(
vllm
::
kFE4M3fn
)
AWQ_GET_IF_M1
(
vllm
::
kU4
,
8
,
8
,
256
)
AWQ_GET_IF_M1
(
vllm
::
kU4
,
8
,
4
,
128
)
FP4_GET_IF
(
vllm
::
kFE2M1f
)
A
WQ
_GET_IF
_M234
(
vllm
::
kU4
,
16
,
4
,
256
)
A
WQ
_GET_IF
_M234
(
vllm
::
kU
4
,
8
,
4
,
128
)
A
CT
_GET_IF
(
vllm
::
kU4
B8
)
A
CT
_GET_IF
(
vllm
::
kU
8B
128
)
return
kernel
;
}
...
...
@@ -457,19 +439,19 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
for
(
int
i
=
0
;
i
<
thread_configs_size
;
i
++
)
{
thread_config_t
th_config
=
thread_configs
[
i
];
if
(
!
is_valid_config
(
th_config
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
max_shared_mem
))
{
if
(
!
is_valid_config
(
th_config
,
m_block_size_8
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
max_shared_mem
))
{
continue
;
}
int
cache_size
=
get_kernel_cache_size
(
th_config
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
);
th_config
,
m_block_size_8
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
);
int
group_blocks
=
0
;
if
(
!
has_act_order
)
{
group_blocks
=
group_size
==
-
1
?
-
1
:
group_size
/
16
;
group_blocks
=
group_size
==
-
1
?
-
1
:
(
group_size
/
16
)
;
}
auto
kernel
=
get_marlin_kernel
<
scalar_t
>
(
...
...
@@ -501,7 +483,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
template
<
typename
scalar_t
>
void
marlin_mm
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
C_tmp
,
void
*
s
,
void
*
zp
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
void
*
s2
,
void
*
zp
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
void
*
sorted_token_ids
,
void
*
expert_ids
,
void
*
num_tokens_past_padded
,
void
*
topk_weights
,
int
moe_block_size
,
int
top_k
,
bool
mul_topk_weights
,
bool
is_ep
,
...
...
@@ -520,8 +502,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
"q_type must be u4 or u8 when has_zp = True. Got = "
,
q_type
.
str
());
}
else
{
TORCH_CHECK
(
q_type
==
vllm
::
kU4B8
||
q_type
==
vllm
::
kU8B128
,
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = "
,
q_type
==
vllm
::
kU4B8
||
q_type
==
vllm
::
kU8B128
||
q_type
==
vllm
::
kFE4M3fn
||
q_type
==
vllm
::
kFE2M1f
,
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
"has_zp = False. Got = "
,
q_type
.
str
());
}
...
...
@@ -555,6 +539,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
int4
*
C_ptr
=
(
int4
*
)
C
;
int4
*
C_tmp_ptr
=
(
int4
*
)
C_tmp
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
const
uint16_t
*
s2_ptr
=
(
const
uint16_t
*
)
s2
;
const
int4
*
zp_ptr
=
(
const
int4
*
)
zp
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
;
const
int
*
perm_ptr
=
(
const
int
*
)
perm
;
...
...
@@ -631,18 +616,18 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
int
thread_k_blocks
=
thread_k
/
16
;
int
thread_n_blocks
=
thread_n
/
16
;
TORCH_CHECK
(
is_valid_config
(
thread_tfg
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
max_shared_mem
)
,
"Invalid thread config: thread_m_blocks = "
,
th
re
a
d_m
_blocks
,
", thread_k
= "
,
thread_
tfg
.
thread_k
,
", thread_
n
= "
,
thread_tfg
.
thread_
n
,
",
num_
thread
s
= "
,
thread_tfg
.
num_
thread
s
,
" for MKN = ["
,
prob_m
,
", "
,
prob_k
,
",
"
,
prob_
n
,
"
] and num_bits = "
,
num_bits
,
",
g
ro
up_size = "
,
group_size
,
", has_act_order = "
,
has_act_order
,
", is_k_full = "
,
is_k_full
,
",
has_zp = "
,
has_zp
,
",
i
s_zp
_float
= "
,
i
s_zp
_float
,
", max_shared_mem = "
,
max_shared_mem
);
TORCH_CHECK
(
is_valid_config
(
thread_tfg
,
m_block_size_8
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
max_sha
red_m
em
)
,
"Invalid thread config: thread_m_blocks
= "
,
thread_
m_blocks
,
", thread_
k
= "
,
thread_tfg
.
thread_
k
,
", thread
_n
= "
,
thread_tfg
.
thread
_n
,
", num_threads = "
,
thread_tfg
.
num_threads
,
" for MKN = [
"
,
prob_
m
,
"
, "
,
prob_k
,
",
"
,
p
ro
b_n
,
"] and num_bits = "
,
num_bits
,
", group_size = "
,
group_size
,
", has_act_order = "
,
has_act_order
,
",
is_k_full = "
,
is_k_full
,
",
ha
s_zp = "
,
ha
s_zp
,
", is_zp_float = "
,
is_zp_float
,
", max_shared_mem = "
,
max_shared_mem
);
auto
kernel
=
get_marlin_kernel
<
scalar_t
>
(
q_type
,
thread_m_blocks
,
thread_n_blocks
,
thread_k_blocks
,
m_block_size_8
,
...
...
@@ -663,10 +648,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel
<<<
blocks
,
num_threads
,
max_shared_mem
,
stream
>>>
(
A_ptr
,
B_ptr
,
C_ptr
,
C_tmp_ptr
,
s_ptr
,
zp_ptr
,
g_idx_ptr
,
A_ptr
,
B_ptr
,
C_ptr
,
C_tmp_ptr
,
s_ptr
,
s2_ptr
,
zp_ptr
,
g_idx_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
num_tokens_past_padded_ptr
,
topk_weights_ptr
,
top_k
,
mul_topk_weights
,
is_ep
,
num_groups
,
prob_m
,
prob_n
,
prob_k
,
locks
,
use_atomic_add
,
use_fp32_reduce
);
prob_n
,
prob_k
,
locks
,
use_atomic_add
,
use_fp32_reduce
,
max_shared_mem
);
// clang-format on
}
...
...
@@ -675,6 +660,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
torch
::
Tensor
moe_wna16_marlin_gemm
(
torch
::
Tensor
&
a
,
std
::
optional
<
torch
::
Tensor
>
const
&
c_or_none
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
global_scale_or_none
,
std
::
optional
<
torch
::
Tensor
>
const
&
b_zeros_or_none
,
std
::
optional
<
torch
::
Tensor
>
const
&
g_idx_or_none
,
std
::
optional
<
torch
::
Tensor
>
const
&
perm_or_none
,
torch
::
Tensor
&
workspace
,
...
...
@@ -826,6 +812,17 @@ torch::Tensor moe_wna16_marlin_gemm(
}
}
torch
::
Tensor
global_scale
;
if
(
global_scale_or_none
.
has_value
())
{
global_scale
=
global_scale_or_none
.
value
();
TORCH_CHECK
(
b_q_type
==
vllm
::
kFE2M1f
,
"global_scale can only be used for float4_e2m1f."
);
}
else
{
global_scale
=
torch
::
empty
({
0
},
options
);
TORCH_CHECK
(
!
(
b_q_type
==
vllm
::
kFE2M1f
),
"the global_scale parameter must be passed for float4_e2m1f."
);
}
torch
::
Tensor
b_zeros
;
if
(
b_zeros_or_none
.
has_value
())
{
b_zeros
=
b_zeros_or_none
.
value
();
...
...
@@ -838,13 +835,15 @@ torch::Tensor moe_wna16_marlin_gemm(
if
(
has_zp
)
{
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4
,
"b_q_type must be u4 when has_zp = True. Got = "
,
b_q_type
.
str
());
b_q_type
==
vllm
::
kU4
||
b_q_type
==
vllm
::
kU8
,
"b_q_type must be u4
or u8
when has_zp = True. Got = "
,
b_q_type
.
str
());
}
else
{
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4B8
||
b_q_type
==
vllm
::
kU8B128
,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = "
,
b_q_type
.
str
());
TORCH_CHECK
(
b_q_type
==
vllm
::
kU4B8
||
b_q_type
==
vllm
::
kU8B128
||
b_q_type
==
vllm
::
kFE4M3fn
||
b_q_type
==
vllm
::
kFE2M1f
,
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or "
"float4_e2m1f when "
"has_zp = False. Got = "
,
b_q_type
.
str
());
}
if
(
has_zp
&&
is_zp_float
)
{
...
...
@@ -889,9 +888,16 @@ torch::Tensor moe_wna16_marlin_gemm(
int
dev
=
a
.
get_device
();
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
void
*
scales_ptr
;
if
(
b_q_type
==
vllm
::
kFE2M1f
)
{
scales_ptr
=
b_scales
.
data_ptr
<
at
::
Float8_e4m3fn
>
();
}
else
{
scales_ptr
=
b_scales
.
data_ptr
<
at
::
Half
>
();
}
MARLIN_NAMESPACE_NAME
::
marlin_mm
<
half
>
(
a
.
data_ptr
<
at
::
Half
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
Half
>
(),
c_tmp
.
data_ptr
<
float
>
(),
b
_scale
s
.
data_ptr
<
at
::
Half
>
(),
c_tmp
.
data_ptr
<
float
>
(),
scales_ptr
,
global
_scale
.
data_ptr
<
at
::
Half
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
Half
>
(),
sorted_token_ids
.
data_ptr
(),
expert_ids
.
data_ptr
(),
num_tokens_past_padded
.
data_ptr
(),
...
...
@@ -901,11 +907,18 @@ torch::Tensor moe_wna16_marlin_gemm(
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
);
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
void
*
scales_ptr
;
if
(
b_q_type
==
vllm
::
kFE2M1f
)
{
scales_ptr
=
b_scales
.
data_ptr
<
at
::
Float8_e4m3fn
>
();
}
else
{
scales_ptr
=
b_scales
.
data_ptr
<
at
::
BFloat16
>
();
}
MARLIN_NAMESPACE_NAME
::
marlin_mm
<
nv_bfloat16
>
(
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
c_tmp
.
data_ptr
<
float
>
(),
b
_scale
s
.
data_ptr
<
at
::
BFloat16
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
c_tmp
.
data_ptr
<
float
>
(),
scales_ptr
,
global
_scale
.
data_ptr
<
at
::
BFloat16
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
sorted_token_ids
.
data_ptr
(),
expert_ids
.
data_ptr
(),
num_tokens_past_padded
.
data_ptr
(),
topk_weights
.
data_ptr
(),
moe_block_size
,
top_k
,
mul_topk_weights
,
is_ep
,
size_m
,
size_n
,
size_k
,
...
...
csrc/moe/moe_align_sum_kernels.cu
View file @
7a985548
...
...
@@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
}
if
(
use_global_memory
)
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
VLLM_DISPATCH_INTEGRAL_
AND_UNSIGNED_
TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_global_mem_kernel"
,
[
&
]
{
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
...
...
@@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
cumsum_buffer
.
data_ptr
<
int32_t
>
());
});
}
else
if
(
use_i16
)
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
VLLM_DISPATCH_INTEGRAL_
AND_UNSIGNED_
TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
// set dynamic shared mem
auto
kernel
=
...
...
@@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
topk_ids
.
numel
());
});
}
else
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
VLLM_DISPATCH_INTEGRAL_
AND_UNSIGNED_
TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
auto
kernel
=
vllm
::
moe
::
moe_align_block_size_kernel
<
scalar_t
,
int32_t
>
;
...
...
@@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
TORCH_CHECK
(
num_experts
==
256
,
"sgl_moe_align_block_size kernel only supports deepseek v3."
);
VLLM_DISPATCH_INTEGRAL_TYPES
(
VLLM_DISPATCH_INTEGRAL_
AND_UNSIGNED_
TYPES
(
topk_ids
.
scalar_type
(),
"sgl_moe_align_block_size_kernel"
,
[
&
]
{
// calc needed amount of shared mem for `cumsum` tensors
auto
options_int
=
...
...
csrc/moe/moe_permute_unpermute_op.cu
0 → 100644
View file @
7a985548
#include <c10/core/ScalarType.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "permute_unpermute_kernels/moe_permute_unpermute_kernel.h"
#include "permute_unpermute_kernels/dispatch.h"
#include "core/registration.h"
void
moe_permute
(
const
torch
::
Tensor
&
input
,
// [n_token, hidden]
const
torch
::
Tensor
&
topk_weights
,
//[n_token, topk]
torch
::
Tensor
&
topk_ids
,
// [n_token, topk]
const
torch
::
Tensor
&
token_expert_indicies
,
// [n_token, topk]
const
std
::
optional
<
torch
::
Tensor
>&
expert_map
,
// [n_expert]
int64_t
n_expert
,
int64_t
n_local_expert
,
int64_t
topk
,
const
std
::
optional
<
int64_t
>&
align_block_size
,
torch
::
Tensor
&
permuted_input
,
// [topk * n_token/align_block_size_m, hidden]
torch
::
Tensor
&
expert_first_token_offset
,
// [n_local_expert + 1]
torch
::
Tensor
&
src_row_id2dst_row_id_map
,
// [n_token, topk]
torch
::
Tensor
&
m_indices
)
{
// [align_expand_m]
TORCH_CHECK
(
topk_weights
.
scalar_type
()
==
at
::
ScalarType
::
Float
,
"topk_weights must be float32"
);
TORCH_CHECK
(
expert_first_token_offset
.
scalar_type
()
==
at
::
ScalarType
::
Long
,
"expert_first_token_offset must be int64"
);
TORCH_CHECK
(
topk_ids
.
scalar_type
()
==
at
::
ScalarType
::
Int
,
"topk_ids must be int32"
);
TORCH_CHECK
(
token_expert_indicies
.
scalar_type
()
==
at
::
ScalarType
::
Int
,
"token_expert_indicies must be int32"
);
TORCH_CHECK
(
src_row_id2dst_row_id_map
.
scalar_type
()
==
at
::
ScalarType
::
Int
,
"src_row_id2dst_row_id_map must be int32"
);
TORCH_CHECK
(
expert_first_token_offset
.
size
(
0
)
==
n_local_expert
+
1
,
"expert_first_token_offset shape != n_local_expert+1"
)
TORCH_CHECK
(
src_row_id2dst_row_id_map
.
sizes
()
==
token_expert_indicies
.
sizes
(),
"token_expert_indicies shape must be same as src_row_id2dst_row_id_map"
);
auto
n_token
=
input
.
sizes
()[
0
];
auto
n_hidden
=
input
.
sizes
()[
1
];
auto
align_block_size_value
=
align_block_size
.
has_value
()
?
align_block_size
.
value
()
:
-
1
;
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
const
long
sorter_size
=
CubKeyValueSorter
::
getWorkspaceSize
(
n_token
*
topk
,
n_expert
);
auto
sort_workspace
=
torch
::
empty
(
{
sorter_size
},
torch
::
dtype
(
torch
::
kInt8
).
device
(
torch
::
kCUDA
).
requires_grad
(
false
));
auto
permuted_experts_id
=
torch
::
empty_like
(
topk_ids
);
auto
dst_row_id2src_row_id_map
=
torch
::
empty_like
(
src_row_id2dst_row_id_map
);
auto
align_expert_first_token_offset
=
torch
::
zeros_like
(
expert_first_token_offset
);
CubKeyValueSorter
sorter
{};
int64_t
*
valid_num_ptr
=
nullptr
;
// pre-process kernel for expert-parallelism:
// no local expert id plus "n_expert" offset for priority to local expert
// map local expert id [n, .., n+n_local_expert-1] to [0, n_local_expert -1]
// For example, 4 expert with ep_size=2. ep_rank=1 owns global expert id
// [2,3] with expert_map[-1, -1, 0, 1], preprocess_topk_id process topk_ids
// and map global expert id [2, 3] to local_expert id [0, 1] and map global
// expert id [0, 1] ( not in ep rank=1) to [4, 5] by plus n_expert. This map
// operation is to make local expert high priority in following sort topk_ids
// and scan local expert_first_token_offset for each ep rank for next group
// gemm.
if
(
expert_map
.
has_value
())
{
const
int
*
expert_map_ptr
=
get_ptr
<
int
>
(
expert_map
.
value
());
valid_num_ptr
=
get_ptr
<
int64_t
>
(
expert_first_token_offset
)
+
n_local_expert
;
preprocessTopkIdLauncher
(
get_ptr
<
int
>
(
topk_ids
),
n_token
*
topk
,
expert_map_ptr
,
n_expert
,
stream
);
}
// expert sort topk expert id and scan expert id get expert_first_token_offset
sortAndScanExpert
(
get_ptr
<
int
>
(
topk_ids
),
get_ptr
<
int
>
(
token_expert_indicies
),
get_ptr
<
int
>
(
permuted_experts_id
),
get_ptr
<
int
>
(
dst_row_id2src_row_id_map
),
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
n_token
,
n_expert
,
n_local_expert
,
topk
,
sorter
,
get_ptr
<
int
>
(
sort_workspace
),
stream
);
// dispatch expandInputRowsKernelLauncher
MOE_DISPATCH
(
input
.
scalar_type
(),
[
&
]
{
expandInputRowsKernelLauncher
<
scalar_t
>
(
get_ptr
<
scalar_t
>
(
input
),
get_ptr
<
scalar_t
>
(
permuted_input
),
get_ptr
<
float
>
(
topk_weights
),
get_ptr
<
int
>
(
permuted_experts_id
),
get_ptr
<
int
>
(
dst_row_id2src_row_id_map
),
get_ptr
<
int
>
(
src_row_id2dst_row_id_map
),
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
n_token
,
valid_num_ptr
,
n_hidden
,
topk
,
n_local_expert
,
align_block_size_value
,
stream
);
});
// get m_indices and update expert_first_token_offset with align block
getMIndices
(
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
get_ptr
<
int64_t
>
(
align_expert_first_token_offset
),
get_ptr
<
int
>
(
m_indices
),
n_local_expert
,
align_block_size_value
,
stream
);
if
(
align_block_size
.
has_value
())
{
// update align_expert_first_token_offset
expert_first_token_offset
.
copy_
(
align_expert_first_token_offset
);
}
}
void
moe_unpermute
(
const
torch
::
Tensor
&
permuted_hidden_states
,
// [n_token * topk, hidden]
const
torch
::
Tensor
&
topk_weights
,
//[n_token, topk]
const
torch
::
Tensor
&
topk_ids
,
// [n_token, topk]
const
torch
::
Tensor
&
src_row_id2dst_row_id_map
,
// [n_token, topk]
const
torch
::
Tensor
&
expert_first_token_offset
,
// [n_local_expert+1]
int64_t
n_expert
,
int64_t
n_local_expert
,
int64_t
topk
,
torch
::
Tensor
&
hidden_states
// [n_token, hidden]
)
{
TORCH_CHECK
(
src_row_id2dst_row_id_map
.
sizes
()
==
topk_ids
.
sizes
(),
"topk_ids shape must be same as src_row_id2dst_row_id_map"
);
TORCH_CHECK
(
topk_ids
.
scalar_type
()
==
at
::
ScalarType
::
Int
,
"topk_ids must be int32"
);
TORCH_CHECK
(
permuted_hidden_states
.
scalar_type
()
==
hidden_states
.
scalar_type
(),
"topk_ids dtype must be same as src_row_id2dst_row_id_map"
);
auto
n_token
=
hidden_states
.
size
(
0
);
auto
n_hidden
=
hidden_states
.
size
(
1
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
const
int64_t
*
valid_ptr
=
get_ptr
<
int64_t
>
(
expert_first_token_offset
)
+
n_local_expert
;
MOE_DISPATCH
(
hidden_states
.
scalar_type
(),
[
&
]
{
finalizeMoeRoutingKernelLauncher
<
scalar_t
,
scalar_t
>
(
get_ptr
<
scalar_t
>
(
permuted_hidden_states
),
get_ptr
<
scalar_t
>
(
hidden_states
),
get_ptr
<
float
>
(
topk_weights
),
get_ptr
<
int
>
(
src_row_id2dst_row_id_map
),
get_ptr
<
int
>
(
topk_ids
),
n_token
,
n_hidden
,
topk
,
valid_ptr
,
stream
);
});
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"moe_permute"
,
&
moe_permute
);
m
.
impl
(
"moe_unpermute"
,
&
moe_unpermute
);
}
\ No newline at end of file
csrc/moe/moe_wna16_utils.h
View file @
7a985548
...
...
@@ -108,11 +108,11 @@ __device__ inline void dequant<half2, 4>(int q, half2* res) {
const
int
MUL
=
0x2c002c00
;
const
int
ADD
=
0xd400d400
;
int
lo0
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi0
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
int
lo0
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi0
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
q
>>=
8
;
int
lo1
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi1
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
int
lo1
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi1
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
res
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo0
),
*
reinterpret_cast
<
const
half2
*>
(
&
SUB
));
...
...
@@ -149,13 +149,13 @@ __device__ inline void dequant<nv_bfloat162, 4>(int q, nv_bfloat162* res) {
static
constexpr
uint32_t
MASK
=
0x000f000f
;
static
constexpr
uint32_t
EX
=
0x43004300
;
int
lo0
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
int
lo0
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
int
hi0
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
int
hi0
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
int
lo1
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
int
lo1
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
int
hi1
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
int
hi1
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
static
constexpr
uint32_t
MUL
=
0x3F803F80
;
static
constexpr
uint32_t
ADD
=
0xC300C300
;
...
...
Prev
1
2
3
4
5
6
7
8
9
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment