Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7e63ef82
Commit
7e63ef82
authored
Jan 21, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.14.0' into v0.14.0-dev
parents
8cbcac5d
b17039bc
Changes
686
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
644 additions
and
515 deletions
+644
-515
csrc/cpu/cpu_wna16.cpp
csrc/cpu/cpu_wna16.cpp
+9
-9
csrc/cpu/dnnl_helper.cpp
csrc/cpu/dnnl_helper.cpp
+6
-6
csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp
csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp
+33
-0
csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp
csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp
+38
-0
csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp
csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp
+19
-0
csrc/cpu/scratchpad_manager.cpp
csrc/cpu/scratchpad_manager.cpp
+0
-23
csrc/cpu/scratchpad_manager.h
csrc/cpu/scratchpad_manager.h
+0
-31
csrc/cpu/torch_bindings.cpp
csrc/cpu/torch_bindings.cpp
+24
-0
csrc/cpu/utils.cpp
csrc/cpu/utils.cpp
+62
-15
csrc/cpu/utils.hpp
csrc/cpu/utils.hpp
+69
-22
csrc/cumem_allocator.cpp
csrc/cumem_allocator.cpp
+10
-0
csrc/fused_qknorm_rope_kernel.cu
csrc/fused_qknorm_rope_kernel.cu
+61
-53
csrc/moe/grouped_topk_kernels.cu
csrc/moe/grouped_topk_kernels.cu
+105
-66
csrc/moe/marlin_moe_wna16/.gitignore
csrc/moe/marlin_moe_wna16/.gitignore
+1
-0
csrc/moe/marlin_moe_wna16/generate_kernels.py
csrc/moe/marlin_moe_wna16/generate_kernels.py
+76
-56
csrc/moe/marlin_moe_wna16/kernel.h
csrc/moe/marlin_moe_wna16/kernel.h
+14
-14
csrc/moe/marlin_moe_wna16/marlin_template.h
csrc/moe/marlin_moe_wna16/marlin_template.h
+57
-179
csrc/moe/marlin_moe_wna16/ops.cu
csrc/moe/marlin_moe_wna16/ops.cu
+44
-38
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+1
-1
csrc/ops.h
csrc/ops.h
+15
-2
No files found.
Too many changes to show.
To preserve performance only
686 of 686+
files are displayed.
Plain diff
Email patch
csrc/cpu/cpu_wna16.cpp
View file @
7e63ef82
#include "cpu_types.hpp"
#include "scratchpad_manager.h"
#include "utils.hpp"
#include "cpu/cpu_types.hpp"
#include "cpu/utils.hpp"
#ifdef CPU_CAPABILITY_AMXBF16
#include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp"
...
...
@@ -158,7 +157,7 @@ void cpu_gemm_wna16_impl(
// a simple schedule policy, just to hold more B tiles in L2 and make sure
// each thread has tasks
const
int32_t
n_partition_size
=
[
&
]()
{
const
int64_t
cache_size
=
cpu_utils
::
get_l2_size
();
const
int64_t
cache_size
=
cpu_utils
::
get_
available_
l2_size
();
int64_t
ps_cache_limit
=
cache_size
/
(
k_size
*
sizeof
(
scalar_t
));
int64_t
ps_thread_limit
=
n_size
/
thread_num
;
ps_cache_limit
=
...
...
@@ -179,7 +178,7 @@ void cpu_gemm_wna16_impl(
const
int64_t
b_buffer_offset
=
0
;
const
int64_t
c_buffer_offset
=
b_buffer_size
;
const
int64_t
buffer_size
=
b_buffer_size
+
c_buffer_size
;
DNNL
ScratchPadManager
::
get_
dnnl_
scratchpad_manager
()
->
realloc
(
buffer_size
*
cpu_utils
::
ScratchPadManager
::
get_scratchpad_manager
()
->
realloc
(
buffer_size
*
thread_num
);
alignas
(
64
)
cpu_utils
::
Counter
counter
;
...
...
@@ -190,7 +189,8 @@ void cpu_gemm_wna16_impl(
scalar_t
*
__restrict__
b_buffer
=
nullptr
;
float
*
__restrict__
c_buffer
=
nullptr
;
{
uint8_t
*
buffer_ptr
=
DNNLScratchPadManager
::
get_dnnl_scratchpad_manager
()
uint8_t
*
buffer_ptr
=
cpu_utils
::
ScratchPadManager
::
get_scratchpad_manager
()
->
get_data
<
uint8_t
>
()
+
thread_id
*
buffer_size
;
b_buffer
=
reinterpret_cast
<
scalar_t
*>
(
buffer_ptr
+
b_buffer_offset
);
...
...
csrc/cpu/dnnl_helper.cpp
View file @
7e63ef82
...
...
@@ -4,8 +4,8 @@
#include "common/memory_desc.hpp"
#include "common/memory.hpp"
#include "
dnnl_helper.h
"
#include "
scratchpad_manag
er.h"
#include "
cpu/utils.hpp
"
#include "
cpu/dnnl_help
er.h"
static
dnnl
::
engine
&
default_engine
()
{
static
dnnl
::
engine
engine
(
dnnl
::
engine
::
kind
::
cpu
,
0
);
...
...
@@ -274,7 +274,7 @@ void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
auto
&&
[
scratchpad_storage
,
scratchpad_mem_desc
]
=
get_runtime_memory_ptr
(
5
);
scratchpad_storage
->
set_data_handle
(
DNNL
ScratchPadManager
::
get_
dnnl_
scratchpad_manager
()
->
get_data
<
void
>
());
cpu_utils
::
ScratchPadManager
::
get_scratchpad_manager
()
->
get_data
<
void
>
());
matmul
.
execute
(
default_stream
(),
memory_cache_
);
default_stream
().
wait
();
...
...
@@ -294,7 +294,7 @@ dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
return
m_size_cache_
->
get_or_create
(
key
,
[
&
]()
{
dnnl
::
matmul
::
primitive_desc
desc
=
this
->
create_primitive_desc
(
key
,
false
);
auto
manager
=
DNNL
ScratchPadManager
::
get_
dnnl_
scratchpad_manager
();
auto
manager
=
cpu_utils
::
ScratchPadManager
::
get_scratchpad_manager
();
manager
->
realloc
(
desc
.
scratchpad_desc
().
get_size
());
return
dnnl
::
matmul
(
desc
);
});
...
...
@@ -470,7 +470,7 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) {
auto
&&
[
scratchpad_storage
,
scratchpad_mem_desc
]
=
get_runtime_memory_ptr
(
3
);
scratchpad_storage
->
set_data_handle
(
DNNL
ScratchPadManager
::
get_
dnnl_
scratchpad_manager
()
->
get_data
<
void
>
());
cpu_utils
::
ScratchPadManager
::
get_scratchpad_manager
()
->
get_data
<
void
>
());
matmul
.
execute
(
default_stream
(),
memory_cache_
);
default_stream
().
wait
();
...
...
@@ -486,7 +486,7 @@ dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache(
}
return
m_size_cache_
->
get_or_create
(
key
,
[
&
]()
{
dnnl
::
matmul
::
primitive_desc
desc
=
this
->
create_primitive_desc
(
key
,
false
);
auto
manager
=
DNNL
ScratchPadManager
::
get_
dnnl_
scratchpad_manager
();
auto
manager
=
cpu_utils
::
ScratchPadManager
::
get_scratchpad_manager
();
manager
->
realloc
(
desc
.
scratchpad_desc
().
get_size
());
return
dnnl
::
matmul
(
desc
);
});
...
...
csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp
View file @
7e63ef82
...
...
@@ -235,6 +235,39 @@ class MicroGemm<cpu_utils::ISA::AMX, scalar_t> {
}
}
static
void
pack_weight
(
const
scalar_t
*
__restrict__
weight
,
scalar_t
*
__restrict__
packed_weight
,
const
int32_t
output_size
,
const
int32_t
input_size
)
{
constexpr
int32_t
elem_num_per_group
=
4
/
sizeof
(
scalar_t
);
TORCH_CHECK_EQ
(
output_size
%
16
,
0
);
TORCH_CHECK_EQ
(
input_size
%
(
16
*
elem_num_per_group
),
0
);
const
int32_t
output_group_num
=
output_size
/
16
;
const
int32_t
input_32b_num
=
input_size
/
elem_num_per_group
;
for
(
int32_t
output_group_idx
=
0
;
output_group_idx
<
output_group_num
;
++
output_group_idx
)
{
const
int32_t
*
__restrict__
weight_32b
=
reinterpret_cast
<
const
int32_t
*>
(
weight
);
int32_t
*
__restrict__
packed_weight_32b
=
reinterpret_cast
<
int32_t
*>
(
packed_weight
);
for
(
int32_t
output_idx
=
0
;
output_idx
<
16
;
++
output_idx
)
{
for
(
int32_t
weight_offset
=
0
,
packed_offset
=
0
;
weight_offset
<
input_32b_num
;
++
weight_offset
,
packed_offset
+=
16
)
{
packed_weight_32b
[
packed_offset
]
=
weight_32b
[
weight_offset
];
}
// update
weight_32b
+=
input_32b_num
;
packed_weight_32b
+=
1
;
}
// update
weight
+=
16
*
input_size
;
packed_weight
+=
16
*
input_size
;
}
}
private:
alignas
(
64
)
__tilecfg
amx_tile_config_
;
int32_t
curr_m_
;
...
...
csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp
View file @
7e63ef82
...
...
@@ -13,6 +13,9 @@ namespace cpu_micro_gemm {
#define CPU_MICRO_GEMM_PARAMS \
a_ptr, b_ptr, c_ptr, m, k, lda, b_n_group_stride, ldc, accum_c
// Note: weights for MicroGemm should be packed as (output_size / 16) contiguous
// blocks, means the logical shape of blocks is [16, input_size]. And the actual
// layout of blocks can be ISA-specific.
template
<
cpu_utils
::
ISA
isa
,
typename
scalar_t
>
class
MicroGemm
{
public:
...
...
@@ -86,6 +89,41 @@ FORCE_INLINE void bias_epilogue(float* __restrict__ c_ptr,
curr_d
+=
ldd
;
}
}
template
<
int32_t
n_size
,
typename
scalar_t
>
FORCE_INLINE
void
add_bias_epilogue
(
float
*
c_ptr
,
float
*
d_ptr
,
scalar_t
*
__restrict__
bias_ptr
,
const
int32_t
m
,
const
int64_t
ldc
,
const
int64_t
ldd
)
{
using
scalar_vec_t
=
typename
cpu_utils
::
VecTypeTrait
<
scalar_t
>::
vec_t
;
static_assert
(
n_size
%
16
==
0
);
constexpr
int32_t
n_group_num
=
n_size
/
16
;
static_assert
(
n_group_num
<=
16
);
vec_op
::
FP32Vec16
bias_vecs
[
n_group_num
];
scalar_t
*
__restrict__
curr_bias
=
bias_ptr
;
vec_op
::
unroll_loop
<
int32_t
,
n_group_num
>
([
&
](
int32_t
i
)
{
scalar_vec_t
vec
(
curr_bias
);
bias_vecs
[
i
]
=
vec_op
::
FP32Vec16
(
vec
);
curr_bias
+=
16
;
});
float
*
curr_c
=
c_ptr
;
float
*
curr_d
=
d_ptr
;
for
(
int32_t
i
=
0
;
i
<
m
;
++
i
)
{
float
*
curr_c_iter
=
curr_c
;
float
*
curr_d_iter
=
curr_d
;
vec_op
::
unroll_loop
<
int32_t
,
n_group_num
>
([
&
](
int32_t
n_g_idx
)
{
vec_op
::
FP32Vec16
c_vec_fp32
(
curr_c_iter
);
c_vec_fp32
=
c_vec_fp32
+
bias_vecs
[
n_g_idx
];
c_vec_fp32
.
save
(
curr_d_iter
);
curr_c_iter
+=
16
;
curr_d_iter
+=
16
;
});
curr_c
+=
ldc
;
curr_d
+=
ldd
;
}
}
}
// namespace cpu_micro_gemm
#endif
csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp
View file @
7e63ef82
...
...
@@ -109,6 +109,25 @@ class MicroGemm<cpu_utils::ISA::VEC, scalar_t> {
void
gemm
(
DEFINE_CPU_MICRO_GEMM_PARAMS
)
{
TileGemm82
<
scalar_t
>::
gemm
(
CPU_MICRO_GEMM_PARAMS
);
}
// Note: pack contiguous weight [output_size, input_size] as contiguous
// packed weight [output_size / 16, input_size, 16]
static
void
pack_weight
(
const
scalar_t
*
__restrict__
weight
,
scalar_t
*
__restrict__
packed_weight
,
const
int32_t
output_size
,
const
int32_t
input_size
)
{
TORCH_CHECK_EQ
(
output_size
%
16
,
0
);
for
(
int32_t
o_idx
=
0
;
o_idx
<
output_size
;
++
o_idx
)
{
const
scalar_t
*
__restrict__
curr_weight
=
weight
+
o_idx
*
input_size
;
scalar_t
*
__restrict__
curr_packed_weight
=
packed_weight
+
(
o_idx
/
16
)
*
(
16
*
input_size
)
+
o_idx
%
16
;
for
(
int32_t
i_idx
=
0
;
i_idx
<
input_size
;
++
i_idx
)
{
*
curr_packed_weight
=
*
curr_weight
;
curr_packed_weight
+=
16
;
++
curr_weight
;
}
}
}
};
}
// namespace cpu_micro_gemm
...
...
csrc/cpu/scratchpad_manager.cpp
deleted
100644 → 0
View file @
8cbcac5d
#include <cstdlib>
#include "scratchpad_manager.h"
DNNLScratchPadManager
::
DNNLScratchPadManager
()
:
size_
(
0
),
ptr_
(
nullptr
)
{
this
->
realloc
(
allocation_unit
*
128
);
}
void
DNNLScratchPadManager
::
realloc
(
size_t
new_size
)
{
new_size
=
round
(
new_size
);
if
(
new_size
>
size_
)
{
if
(
ptr_
!=
nullptr
)
{
std
::
free
(
ptr_
);
}
ptr_
=
std
::
aligned_alloc
(
64
,
new_size
);
size_
=
new_size
;
}
}
DNNLScratchPadManager
*
DNNLScratchPadManager
::
get_dnnl_scratchpad_manager
()
{
static
DNNLScratchPadManager
manager
;
return
&
manager
;
}
csrc/cpu/scratchpad_manager.h
deleted
100644 → 0
View file @
8cbcac5d
#ifndef SCRATCHPAD_MANAGER_H
#define SCRATCHPAD_MANAGER_H
#include <cstddef>
#include <cstdio>
class
DNNLScratchPadManager
{
public:
static
constexpr
size_t
allocation_unit
=
4
*
1024
;
// 4KB
static
DNNLScratchPadManager
*
get_dnnl_scratchpad_manager
();
DNNLScratchPadManager
();
template
<
typename
T
>
T
*
get_data
()
{
return
reinterpret_cast
<
T
*>
(
ptr_
);
}
static
size_t
round
(
size_t
size
)
{
return
((
size
+
allocation_unit
-
1
)
/
allocation_unit
)
*
allocation_unit
;
}
void
realloc
(
size_t
new_size
);
private:
size_t
size_
;
void
*
ptr_
;
};
#endif
csrc/cpu/torch_bindings.cpp
View file @
7e63ef82
...
...
@@ -110,6 +110,17 @@ void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight,
const
std
::
optional
<
torch
::
Tensor
>&
bias
,
const
int64_t
pack_factor
,
const
std
::
string
&
isa_hint
);
void
prepack_moe_weight
(
const
torch
::
Tensor
&
weight
,
torch
::
Tensor
&
packed_weight
,
const
std
::
string
&
isa
);
void
cpu_fused_moe
(
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
w13
,
const
torch
::
Tensor
&
w2
,
const
std
::
optional
<
torch
::
Tensor
>&
w13_bias
,
const
std
::
optional
<
torch
::
Tensor
>&
w2_bias
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
topk_id
,
const
std
::
string
&
act
,
const
std
::
string
&
isa
);
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
ops
)
{
// vLLM custom ops
...
...
@@ -296,6 +307,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"pack_factor, str isa_hint) -> ()"
);
ops
.
impl
(
"cpu_gemm_wna16"
,
torch
::
kCPU
,
&
cpu_gemm_wna16
);
#endif
// fused moe
#if defined(__AVX512F__)
ops
.
def
(
"prepack_moe_weight(Tensor weight, Tensor(a1!) packed_weight, str isa) "
"-> ()"
);
ops
.
impl
(
"prepack_moe_weight"
,
torch
::
kCPU
,
&
prepack_moe_weight
);
ops
.
def
(
"cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
"Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
"str act, str isa) -> ()"
);
ops
.
impl
(
"cpu_fused_moe"
,
torch
::
kCPU
,
&
cpu_fused_moe
);
#endif
}
TORCH_LIBRARY_EXPAND
(
CONCAT
(
TORCH_EXTENSION_NAME
,
_utils
),
utils
)
{
...
...
csrc/cpu/utils.cpp
View file @
7e63ef82
...
...
@@ -10,7 +10,7 @@
#define gettid() syscall(SYS_gettid)
#endif
#include "cpu
_type
s.hpp"
#include "cpu
/util
s.hpp"
#ifdef VLLM_NUMA_DISABLED
std
::
string
init_cpu_threads_env
(
const
std
::
string
&
cpu_ids
)
{
...
...
@@ -24,6 +24,8 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
#ifndef VLLM_NUMA_DISABLED
std
::
string
init_cpu_threads_env
(
const
std
::
string
&
cpu_ids
)
{
bitmask
*
omp_cpu_mask
=
numa_parse_cpustring_all
(
cpu_ids
.
c_str
());
TORCH_CHECK
(
omp_cpu_mask
!=
nullptr
,
"Failed to parse CPU string: "
+
cpu_ids
);
TORCH_CHECK
(
omp_cpu_mask
->
size
>
0
);
std
::
vector
<
int
>
omp_cpu_ids
;
omp_cpu_ids
.
reserve
(
omp_cpu_mask
->
size
);
...
...
@@ -44,20 +46,12 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
// Memory node binding
if
(
numa_available
()
!=
-
1
)
{
int
mem_node_id
=
numa_node_of_cpu
(
omp_cpu_ids
.
front
());
std
::
set
<
int
>
node_ids
;
for
(
const
auto
&
cpu_id
:
omp_cpu_ids
)
{
int
node_id
=
numa_node_of_cpu
(
cpu_id
);
if
(
node_id
!=
-
1
)
{
node_ids
.
insert
(
node_id
);
}
if
(
node_id
!=
mem_node_id
)
{
TORCH_WARN
(
"CPU "
,
cpu_id
,
" is on NUMA node "
,
node_id
,
", but CPU "
,
omp_cpu_ids
.
front
(),
" is on NUMA node "
,
mem_node_id
,
". All CPUs should be on the same NUMA node for optimal "
"performance. Memory will be bound to NUMA node "
,
mem_node_id
,
"."
);
}
}
// Concatenate all node_ids into a single comma-separated string
if
(
!
node_ids
.
empty
())
{
...
...
@@ -70,7 +64,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
}
bitmask
*
mask
=
numa_parse_nodestring
(
node_ids_str
.
c_str
());
bitmask
*
src_mask
=
numa_get_mem
bin
d
();
bitmask
*
src_mask
=
numa_get_mem
s_allowe
d
();
int
pid
=
getpid
();
...
...
@@ -83,14 +77,45 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
std
::
to_string
(
errno
));
}
// restrict memory allocation node.
// Restrict memory allocation to the selected NUMA node(s).
// Enhances memory locality for the threads bound to those NUMA CPUs.
if
(
node_ids
.
size
()
>
1
)
{
errno
=
0
;
numa_set_interleave_mask
(
mask
);
if
(
errno
!=
0
)
{
TORCH_WARN
(
"numa_set_interleave_mask failed. errno: "
+
std
::
to_string
(
errno
));
}
else
{
TORCH_WARN
(
"NUMA binding: Using INTERLEAVE policy for memory "
"allocation across multiple NUMA nodes (nodes: "
+
node_ids_str
+
"). Memory allocations will be "
"interleaved across the specified NUMA nodes."
);
}
}
else
{
errno
=
0
;
numa_set_membind
(
mask
);
if
(
errno
!=
0
)
{
TORCH_WARN
(
"numa_set_membind failed. errno: "
+
std
::
to_string
(
errno
));
}
else
{
TORCH_WARN
(
"NUMA binding: Using MEMBIND policy for memory "
"allocation on the NUMA nodes ("
+
node_ids_str
+
"). Memory allocations will be "
"strictly bound to these NUMA nodes."
);
}
}
numa_set_strict
(
1
);
numa_free_nodemask
(
mask
);
numa_free_nodemask
(
src_mask
);
}
else
{
TORCH_WARN
(
"numa_parse_nodestring or numa_get_membind failed. errno: "
+
TORCH_WARN
(
"numa_parse_nodestring or numa_get_run_node_mask failed. errno: "
+
std
::
to_string
(
errno
));
}
}
...
...
@@ -138,4 +163,26 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
return
ss
.
str
();
}
#endif
#endif // VLLM_NUMA_DISABLED
namespace
cpu_utils
{
ScratchPadManager
::
ScratchPadManager
()
:
size_
(
0
),
ptr_
(
nullptr
)
{
this
->
realloc
(
allocation_unit
*
128
);
}
void
ScratchPadManager
::
realloc
(
size_t
new_size
)
{
new_size
=
round
(
new_size
);
if
(
new_size
>
size_
)
{
if
(
ptr_
!=
nullptr
)
{
std
::
free
(
ptr_
);
}
ptr_
=
std
::
aligned_alloc
(
64
,
new_size
);
size_
=
new_size
;
}
}
ScratchPadManager
*
ScratchPadManager
::
get_scratchpad_manager
()
{
static
ScratchPadManager
manager
;
return
&
manager
;
}
}
// namespace cpu_utils
csrc/cpu/utils.hpp
View file @
7e63ef82
...
...
@@ -2,19 +2,24 @@
#define UTILS_HPP
#include <atomic>
#include <cassert>
#include <cstdint>
#include <unistd.h>
#include <ATen/cpu/Utils.h>
#if defined(__APPLE__)
#include <sys/sysctl.h>
#endif
#include "cpu_types.hpp"
#include "cpu/cpu_types.hpp"
namespace
cpu_utils
{
enum
class
ISA
{
AMX
,
VEC
};
inline
ISA
get_isa
(
const
std
::
string
&
isa
)
{
if
(
isa
==
"amx"
)
{
return
ISA
::
AMX
;
}
else
if
(
isa
==
"vec"
)
{
return
ISA
::
VEC
;
}
else
{
TORCH_CHECK
(
false
,
"Invalid isa type: "
+
isa
);
}
}
template
<
typename
T
>
struct
VecTypeTrait
{
using
vec_t
=
void
;
...
...
@@ -32,10 +37,12 @@ struct VecTypeTrait<c10::BFloat16> {
};
#endif
#if !defined(__powerpc__)
template
<
>
struct
VecTypeTrait
<
c10
::
Half
>
{
using
vec_t
=
vec_op
::
FP16Vec16
;
};
#endif
struct
Counter
{
std
::
atomic
<
int64_t
>
counter
;
...
...
@@ -48,26 +55,66 @@ struct Counter {
int64_t
acquire_counter
()
{
return
counter
++
;
}
};
inline
int64_t
get_l2_size
()
{
inline
int64_t
get_
available_
l2_size
()
{
static
int64_t
size
=
[]()
{
#if defined(__APPLE__)
// macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname.
int64_t
l2_cache_size
=
0
;
size_t
len
=
sizeof
(
l2_cache_size
);
if
(
sysctlbyname
(
"hw.l2cachesize"
,
&
l2_cache_size
,
&
len
,
NULL
,
0
)
==
0
&&
l2_cache_size
>
0
)
{
return
l2_cache_size
>>
1
;
// use 50% of L2 cache
}
// Fallback if sysctlbyname fails
return
128LL
*
1024
>>
1
;
// use 50% of 128KB
#else
long
l2_cache_size
=
sysconf
(
_SC_LEVEL2_CACHE_SIZE
);
assert
(
l2_cache_size
!=
-
1
);
const
uint32_t
l2_cache_size
=
at
::
cpu
::
L2_cache_size
();
return
l2_cache_size
>>
1
;
// use 50% of L2 cache
#endif
}();
return
size
;
}
template
<
int32_t
alignment_v
,
typename
T
>
inline
T
round_up
(
T
size
)
{
T
alignment
=
alignment_v
;
return
(((
size
+
alignment
-
1
)
/
alignment
)
*
alignment
);
}
template
<
int32_t
alignment_v
,
typename
T
>
inline
T
round_down
(
T
size
)
{
T
alignment
=
alignment_v
;
return
(
size
/
alignment
)
*
alignment
;
}
template
<
typename
T
>
inline
void
print_logits
(
const
char
*
name
,
T
*
ptr
,
int32_t
row
,
int32_t
col
,
int32_t
stride
)
{
std
::
stringstream
ss
;
ss
<<
std
::
fixed
<<
std
::
setprecision
(
5
)
<<
name
<<
": [
\n
"
;
auto
*
curr_logits_buffer
=
ptr
;
for
(
int32_t
m
=
0
;
m
<
row
;
++
m
)
{
for
(
int32_t
n
=
0
;
n
<
col
;
++
n
)
{
ss
<<
curr_logits_buffer
[
n
]
<<
", "
;
}
ss
<<
"
\n
"
;
curr_logits_buffer
+=
stride
;
}
ss
<<
"]
\n
"
;
std
::
printf
(
"%s"
,
ss
.
str
().
c_str
());
}
class
ScratchPadManager
{
public:
static
constexpr
size_t
allocation_unit
=
4
*
1024
;
// 4KB
static
ScratchPadManager
*
get_scratchpad_manager
();
ScratchPadManager
();
template
<
typename
T
>
T
*
get_data
()
{
return
reinterpret_cast
<
T
*>
(
ptr_
);
}
static
size_t
round
(
size_t
size
)
{
return
((
size
+
allocation_unit
-
1
)
/
allocation_unit
)
*
allocation_unit
;
}
void
realloc
(
size_t
new_size
);
private:
size_t
size_
;
void
*
ptr_
;
};
}
// namespace cpu_utils
#endif
csrc/cumem_allocator.cpp
View file @
7e63ef82
...
...
@@ -107,6 +107,16 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
prop
.
location
.
id
=
device
;
prop
.
allocFlags
.
compressionType
=
CU_MEM_ALLOCATION_COMP_NONE
;
#ifndef USE_ROCM
int
flag
=
0
;
CUDA_CHECK
(
cuDeviceGetAttribute
(
&
flag
,
CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED
,
device
));
if
(
flag
)
{
// support GPUDirect RDMA if possible
prop
.
allocFlags
.
gpuDirectRDMACapable
=
1
;
}
#endif
#ifndef USE_ROCM
// Allocate memory using cuMemCreate
CUDA_CHECK
(
cuMemCreate
(
p_memHandle
,
size
,
&
prop
,
0
));
...
...
csrc/fused_qknorm_rope_kernel.cu
View file @
7e63ef82
...
...
@@ -107,7 +107,8 @@ __global__ void fusedQKNormRopeKernel(
void
const
*
k_weight_void
,
// RMSNorm weights for key
void
const
*
cos_sin_cache_void
,
// Pre-computed cos/sin cache
int64_t
const
*
position_ids
,
// Position IDs for RoPE
int
const
num_tokens
// Number of tokens
int
const
num_tokens
,
// Number of tokens
int
const
rotary_dim
// Dimension for RoPE
)
{
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
if
constexpr
((
std
::
is_same_v
<
scalar_t_in
,
c10
::
BFloat16
>
)
||
...
...
@@ -227,22 +228,24 @@ __global__ void fusedQKNormRopeKernel(
// Calculate cache pointer for this position - similar to
// pos_encoding_kernels.cu
T_cache
const
*
cache_ptr
=
cos_sin_cache
+
pos_id
*
head
_dim
;
int
const
embed_dim
=
head
_dim
/
2
;
T_cache
const
*
cache_ptr
=
cos_sin_cache
+
pos_id
*
rotary
_dim
;
int
const
embed_dim
=
rotary
_dim
/
2
;
T_cache
const
*
cos_ptr
=
cache_ptr
;
T_cache
const
*
sin_ptr
=
cache_ptr
+
embed_dim
;
int
const
rotary_lanes
=
rotary_dim
/
numElemsPerThread
;
// rotary range
if
(
laneId
<
rotary_lanes
)
{
if
constexpr
(
interleave
)
{
// Perform interleaving. Use pre-computed cos/sin values.
#pragma unroll
for
(
int
i
=
0
;
i
<
numElemsPerThread
/
2
;
++
i
)
{
int
const
idx0
=
2
*
i
;
int
const
idx1
=
2
*
i
+
1
;
// Global dimension index in the head
int
const
dim_idx
=
laneId
*
numElemsPerThread
+
idx0
;
float
const
val0
=
elements
[
idx0
];
float
const
val1
=
elements
[
idx1
];
int
const
dim_idx
=
laneId
*
numElemsPerThread
+
idx0
;
int
const
half_dim
=
dim_idx
/
2
;
float
const
cos_val
=
CacheConverter
::
convert
(
VLLM_LDG
(
cos_ptr
+
half_dim
));
...
...
@@ -255,19 +258,20 @@ __global__ void fusedQKNormRopeKernel(
}
else
{
// Before data exchange with in warp, we need to sync.
__syncwarp
();
// Get the data from the other half of the warp. Use pre-computed cos/sin
// values.
int
pairOffset
=
(
rotary_dim
/
2
)
/
numElemsPerThread
;
// Get the data from the other half of the warp. Use pre-computed
// cos/sin values.
#pragma unroll
for
(
int
i
=
0
;
i
<
numElemsPerThread
;
i
++
)
{
elements2
[
i
]
=
__shfl_xor_sync
(
FINAL_MASK
,
elements
[
i
],
16
);
if
(
laneId
<
16
)
{
elements2
[
i
]
=
__shfl_xor_sync
(
FINAL_MASK
,
elements
[
i
],
pairOffset
);
if
(
laneId
<
pairOffset
)
{
elements2
[
i
]
=
-
elements2
[
i
];
}
int
dim_idx
=
laneId
*
numElemsPerThread
+
i
;
dim_idx
=
(
dim_idx
*
2
)
%
head_dim
;
dim_idx
=
(
dim_idx
*
2
)
%
rotary_dim
;
int
half_dim
=
dim_idx
/
2
;
// Use pre-computed cos/sin from cache
float
cos_val
=
CacheConverter
::
convert
(
VLLM_LDG
(
cos_ptr
+
half_dim
));
float
sin_val
=
CacheConverter
::
convert
(
VLLM_LDG
(
sin_ptr
+
half_dim
));
...
...
@@ -276,7 +280,7 @@ __global__ void fusedQKNormRopeKernel(
// __shfl_xor_sync does not provide memfence. Need to sync again.
__syncwarp
();
}
}
// Store.
{
vec_T
vec
;
...
...
@@ -312,10 +316,10 @@ template <typename scalar_t_in, typename scalar_t_cache>
void
launchFusedQKNormRope
(
void
*
qkv
,
int
const
num_tokens
,
int
const
num_heads_q
,
int
const
num_heads_k
,
int
const
num_heads_v
,
int
const
head_dim
,
floa
t
const
eps
,
void
const
*
q_weight
,
void
const
*
k
_weight
,
void
const
*
cos_sin_cache
,
bool
const
interleave
,
int64_t
const
*
position_ids
,
cudaStream_t
stream
)
{
in
t
const
rotary_dim
,
float
const
eps
,
void
const
*
q
_weight
,
void
const
*
k_weight
,
void
const
*
cos_sin_cache
,
bool
const
interleave
,
int64_t
const
*
position_ids
,
cudaStream_t
stream
)
{
constexpr
int
blockSize
=
256
;
int
const
warpsPerBlock
=
blockSize
/
32
;
...
...
@@ -332,7 +336,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
fusedQKNormRopeKernel
<
scalar_t_in
,
scalar_t_cache
,
64
,
INTERLEAVE
>
<<<
gridDim
,
blockDim
,
0
,
stream
>>>
(
qkv
,
num_heads_q
,
num_heads_k
,
num_heads_v
,
eps
,
q_weight
,
k_weight
,
cos_sin_cache
,
position_ids
,
num_tokens
);
k_weight
,
cos_sin_cache
,
position_ids
,
num_tokens
,
rotary_dim
);
});
break
;
case
128
:
...
...
@@ -340,7 +344,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
fusedQKNormRopeKernel
<
scalar_t_in
,
scalar_t_cache
,
128
,
INTERLEAVE
>
<<<
gridDim
,
blockDim
,
0
,
stream
>>>
(
qkv
,
num_heads_q
,
num_heads_k
,
num_heads_v
,
eps
,
q_weight
,
k_weight
,
cos_sin_cache
,
position_ids
,
num_tokens
);
k_weight
,
cos_sin_cache
,
position_ids
,
num_tokens
,
rotary_dim
);
});
break
;
case
256
:
...
...
@@ -348,7 +352,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
fusedQKNormRopeKernel
<
scalar_t_in
,
scalar_t_cache
,
256
,
INTERLEAVE
>
<<<
gridDim
,
blockDim
,
0
,
stream
>>>
(
qkv
,
num_heads_q
,
num_heads_k
,
num_heads_v
,
eps
,
q_weight
,
k_weight
,
cos_sin_cache
,
position_ids
,
num_tokens
);
k_weight
,
cos_sin_cache
,
position_ids
,
num_tokens
,
rotary_dim
);
});
break
;
default:
...
...
@@ -392,8 +396,11 @@ void fused_qk_norm_rope(
"Query weights size must match head dimension"
);
TORCH_CHECK
(
k_weight
.
size
(
0
)
==
head_dim
,
"Key weights size must match head dimension"
);
TORCH_CHECK
(
cos_sin_cache
.
size
(
1
)
==
head_dim
,
"Cos/sin cache dimension must match head_dim"
);
TORCH_CHECK
(
cos_sin_cache
.
size
(
1
)
%
2
==
0
,
"rotary_dim must be even"
);
TORCH_CHECK
(
cos_sin_cache
.
size
(
1
)
<=
head_dim
,
"rotary_dim must be less than or equal to head_dim"
);
TORCH_CHECK
(
qkv
.
scalar_type
()
==
q_weight
.
scalar_type
()
&&
qkv
.
scalar_type
()
==
k_weight
.
scalar_type
(),
"qkv, q_weight and k_weight must have the same dtype"
);
...
...
@@ -419,7 +426,8 @@ void fused_qk_norm_rope(
qkv
.
data_ptr
(),
static_cast
<
int
>
(
num_tokens
),
static_cast
<
int
>
(
num_heads_q
),
static_cast
<
int
>
(
num_heads_k
),
static_cast
<
int
>
(
num_heads_v
),
static_cast
<
int
>
(
head_dim
),
static_cast
<
float
>
(
eps
),
q_weight
.
data_ptr
(),
k_weight
.
data_ptr
(),
static_cast
<
int
>
(
cos_sin_cache
.
size
(
1
)),
static_cast
<
float
>
(
eps
),
q_weight
.
data_ptr
(),
k_weight
.
data_ptr
(),
cos_sin_cache
.
data_ptr
(),
!
is_neox
,
reinterpret_cast
<
int64_t
const
*>
(
position_ids
.
data_ptr
()),
stream
);
...
...
csrc/moe/grouped_topk_kernels.cu
View file @
7e63ef82
...
...
@@ -446,15 +446,19 @@ __device__ inline T apply_sigmoid(T val) {
template
<
ScoringFunc
SF
,
typename
T
>
__device__
inline
T
apply_scoring
(
T
val
)
{
if
constexpr
(
SF
==
SCORING_SIGMOID
)
{
if
constexpr
(
SF
==
SCORING_NONE
)
{
return
val
;
}
else
if
constexpr
(
SF
==
SCORING_SIGMOID
)
{
return
apply_sigmoid
(
val
);
}
else
{
static_assert
(
SF
==
SCORING_NONE
||
SF
==
SCORING_SIGMOID
,
"Unsupported ScoringFunc in apply_scoring"
);
return
val
;
}
}
template
<
typename
T
,
ScoringFunc
SF
>
__device__
void
topk_with_k2
(
T
*
output
,
T
const
*
input
,
T
const
*
bias
,
template
<
typename
T
,
typename
BiasT
,
ScoringFunc
SF
>
__device__
void
topk_with_k2
(
T
*
output
,
T
const
*
input
,
Bias
T
const
*
bias
,
cg
::
thread_block_tile
<
32
>
const
&
tile
,
int32_t
const
lane_id
,
int
const
num_experts_per_group
)
{
...
...
@@ -465,7 +469,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
if
(
num_experts_per_group
>
WARP_SIZE
)
{
for
(
int
i
=
lane_id
;
i
<
num_experts_per_group
;
i
+=
WARP_SIZE
)
{
T
value
=
apply_scoring
<
SF
>
(
input
[
i
]);
value
=
value
+
bias
[
i
];
value
=
value
+
static_cast
<
T
>
(
bias
[
i
]
)
;
if
(
value
>
largest
)
{
second_largest
=
largest
;
...
...
@@ -477,7 +481,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
}
else
{
for
(
int
i
=
lane_id
;
i
<
num_experts_per_group
;
i
+=
WARP_SIZE
)
{
T
value
=
apply_scoring
<
SF
>
(
input
[
i
]);
value
=
value
+
bias
[
i
];
value
=
value
+
static_cast
<
T
>
(
bias
[
i
]
)
;
largest
=
value
;
}
}
...
...
@@ -499,8 +503,8 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
}
}
template
<
typename
T
,
ScoringFunc
SF
>
__global__
void
topk_with_k2_kernel
(
T
*
output
,
T
*
input
,
T
const
*
bias
,
template
<
typename
T
,
typename
BiasT
,
ScoringFunc
SF
>
__global__
void
topk_with_k2_kernel
(
T
*
output
,
T
*
input
,
Bias
T
const
*
bias
,
int64_t
const
num_tokens
,
int64_t
const
num_cases
,
int64_t
const
n_group
,
...
...
@@ -513,7 +517,7 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
input
+=
case_id
*
num_experts_per_group
;
// bias is per expert group, offset to current group
int32_t
group_id
=
case_id
%
n_group
;
T
const
*
group_bias
=
bias
+
group_id
*
num_experts_per_group
;
Bias
T
const
*
group_bias
=
bias
+
group_id
*
num_experts_per_group
;
output
+=
case_id
;
cg
::
thread_block
block
=
cg
::
this_thread_block
();
...
...
@@ -522,7 +526,7 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm
volatile
(
"griddepcontrol.wait;"
);
#endif
topk_with_k2
<
T
,
SF
>
(
output
,
input
,
group_bias
,
tile
,
lane_id
,
topk_with_k2
<
T
,
BiasT
,
SF
>
(
output
,
input
,
group_bias
,
tile
,
lane_id
,
num_experts_per_group
);
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
...
...
@@ -530,10 +534,11 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
#endif
}
template
<
typename
T
,
typename
IdxT
,
ScoringFunc
SF
,
int
NGroup
=
-
1
>
template
<
typename
T
,
typename
BiasT
,
typename
IdxT
,
ScoringFunc
SF
,
int
NGroup
=
-
1
>
__global__
void
group_idx_and_topk_idx_kernel
(
T
*
scores
,
T
const
*
group_scores
,
float
*
topk_values
,
IdxT
*
topk_indices
,
T
const
*
bias
,
int64_t
const
num_tokens
,
int64_t
const
n_group
,
Bias
T
const
*
bias
,
int64_t
const
num_tokens
,
int64_t
const
n_group
,
int64_t
const
topk_group
,
int64_t
const
topk
,
int64_t
const
num_experts
,
int64_t
const
num_experts_per_group
,
bool
renormalize
,
double
routed_scaling_factor
)
{
...
...
@@ -619,7 +624,7 @@ __global__ void group_idx_and_topk_idx_kernel(
T
input
=
scores
[
offset
+
i
];
if
(
is_finite
(
input
))
{
T
score
=
apply_scoring
<
SF
>
(
input
);
candidates
=
score
+
bias
[
offset
+
i
];
candidates
=
score
+
static_cast
<
T
>
(
bias
[
offset
+
i
]
)
;
}
}
queue
.
add
(
candidates
,
offset
+
i
);
...
...
@@ -670,10 +675,13 @@ __global__ void group_idx_and_topk_idx_kernel(
if
(
case_id
<
num_tokens
)
{
if
(
if_proceed_next_topk
)
{
float
scale
=
routed_scaling_factor
;
if
(
renormalize
)
{
scale
/=
topk_sum
;
}
for
(
int
i
=
lane_id
;
i
<
topk
;
i
+=
WARP_SIZE
)
{
float
base
=
cuda_cast
<
float
,
T
>
(
s_topk_value
[
i
]);
float
value
=
renormalize
?
(
base
/
topk_sum
*
routed_scaling_factor
)
:
(
base
*
routed_scaling_factor
);
float
value
=
base
*
scale
;
topk_indices
[
i
]
=
s_topk_idx
[
i
];
topk_values
[
i
]
=
value
;
}
...
...
@@ -691,10 +699,10 @@ __global__ void group_idx_and_topk_idx_kernel(
#endif
}
template
<
typename
T
,
typename
IdxT
,
ScoringFunc
SF
>
template
<
typename
T
,
typename
BiasT
,
typename
IdxT
,
ScoringFunc
SF
>
inline
void
launch_group_idx_and_topk_kernel
(
cudaLaunchConfig_t
const
&
config
,
T
*
scores
,
T
*
group_scores
,
float
*
topk_values
,
IdxT
*
topk_indices
,
T
const
*
bias
,
float
*
topk_values
,
IdxT
*
topk_indices
,
Bias
T
const
*
bias
,
int64_t
const
num_tokens
,
int64_t
const
n_group
,
int64_t
const
topk_group
,
int64_t
const
topk
,
int64_t
const
num_experts
,
int64_t
const
num_experts_per_group
,
bool
const
renormalize
,
...
...
@@ -708,36 +716,36 @@ inline void launch_group_idx_and_topk_kernel(
switch
(
n_group
)
{
case
4
:
{
launch
(
&
group_idx_and_topk_idx_kernel
<
T
,
IdxT
,
SF
,
4
>
);
launch
(
&
group_idx_and_topk_idx_kernel
<
T
,
BiasT
,
IdxT
,
SF
,
4
>
);
break
;
}
case
8
:
{
launch
(
&
group_idx_and_topk_idx_kernel
<
T
,
IdxT
,
SF
,
8
>
);
launch
(
&
group_idx_and_topk_idx_kernel
<
T
,
BiasT
,
IdxT
,
SF
,
8
>
);
break
;
}
case
16
:
{
launch
(
&
group_idx_and_topk_idx_kernel
<
T
,
IdxT
,
SF
,
16
>
);
launch
(
&
group_idx_and_topk_idx_kernel
<
T
,
BiasT
,
IdxT
,
SF
,
16
>
);
break
;
}
case
32
:
{
launch
(
&
group_idx_and_topk_idx_kernel
<
T
,
IdxT
,
SF
,
32
>
);
launch
(
&
group_idx_and_topk_idx_kernel
<
T
,
BiasT
,
IdxT
,
SF
,
32
>
);
break
;
}
default:
{
launch
(
&
group_idx_and_topk_idx_kernel
<
T
,
IdxT
,
SF
>
);
launch
(
&
group_idx_and_topk_idx_kernel
<
T
,
BiasT
,
IdxT
,
SF
>
);
break
;
}
}
}
template
<
typename
T
,
typename
IdxT
>
template
<
typename
T
,
typename
BiasT
,
typename
IdxT
>
void
invokeNoAuxTc
(
T
*
scores
,
T
*
group_scores
,
float
*
topk_values
,
IdxT
*
topk_indices
,
T
const
*
bias
,
int64_t
const
num_tokens
,
int64_t
const
num_
expert
s
,
int64_t
const
n
_group
,
int64_t
const
topk
_group
,
int64_t
const
topk
,
bool
const
renormalize
,
double
const
routed_scaling_factor
,
int
const
scoring_func
,
bool
enable_pdl
=
false
,
cudaStream_t
const
stream
=
0
)
{
IdxT
*
topk_indices
,
Bias
T
const
*
bias
,
int64_t
const
num_
token
s
,
int64_t
const
n
um_experts
,
int64_t
const
n
_group
,
int64_t
const
topk
_group
,
int64_t
const
topk
,
bool
const
renormalize
,
double
const
routed_scaling_factor
,
int
const
scoring_func
,
bool
enable_pdl
=
false
,
cudaStream_t
const
stream
=
0
)
{
int64_t
num_cases
=
num_tokens
*
n_group
;
int64_t
topk_with_k2_num_blocks
=
(
num_cases
-
1
)
/
NUM_WARPS_PER_BLOCK
+
1
;
cudaLaunchConfig_t
config
;
...
...
@@ -758,12 +766,12 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
};
switch
(
sf
)
{
case
SCORING_NONE
:
{
auto
*
kernel_instance1
=
&
topk_with_k2_kernel
<
T
,
SCORING_NONE
>
;
auto
*
kernel_instance1
=
&
topk_with_k2_kernel
<
T
,
BiasT
,
SCORING_NONE
>
;
launch_topk_with_k2
(
kernel_instance1
);
break
;
}
case
SCORING_SIGMOID
:
{
auto
*
kernel_instance1
=
&
topk_with_k2_kernel
<
T
,
SCORING_SIGMOID
>
;
auto
*
kernel_instance1
=
&
topk_with_k2_kernel
<
T
,
BiasT
,
SCORING_SIGMOID
>
;
launch_topk_with_k2
(
kernel_instance1
);
break
;
}
...
...
@@ -787,14 +795,14 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
config
.
attrs
=
attrs
;
switch
(
sf
)
{
case
SCORING_NONE
:
{
launch_group_idx_and_topk_kernel
<
T
,
IdxT
,
SCORING_NONE
>
(
launch_group_idx_and_topk_kernel
<
T
,
BiasT
,
IdxT
,
SCORING_NONE
>
(
config
,
scores
,
group_scores
,
topk_values
,
topk_indices
,
bias
,
num_tokens
,
n_group
,
topk_group
,
topk
,
num_experts
,
num_experts_per_group
,
renormalize
,
routed_scaling_factor
);
break
;
}
case
SCORING_SIGMOID
:
{
launch_group_idx_and_topk_kernel
<
T
,
IdxT
,
SCORING_SIGMOID
>
(
launch_group_idx_and_topk_kernel
<
T
,
BiasT
,
IdxT
,
SCORING_SIGMOID
>
(
config
,
scores
,
group_scores
,
topk_values
,
topk_indices
,
bias
,
num_tokens
,
n_group
,
topk_group
,
topk
,
num_experts
,
num_experts_per_group
,
renormalize
,
routed_scaling_factor
);
...
...
@@ -805,17 +813,23 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
}
}
#define INSTANTIATE_NOAUX_TC(T, IdxT)
\
template void invokeNoAuxTc<T, IdxT>(
\
#define INSTANTIATE_NOAUX_TC(T,
BiasT,
IdxT) \
template void invokeNoAuxTc<T,
BiasT,
IdxT>( \
T * scores, T * group_scores, float* topk_values, IdxT* topk_indices, \
T const* bias, int64_t const num_tokens, int64_t const num_experts,
\
Bias
T const* bias, int64_t const num_tokens, int64_t const num_experts, \
int64_t const n_group, int64_t const topk_group, int64_t const topk, \
bool const renormalize, double const routed_scaling_factor, \
int const scoring_func, bool enable_pdl, cudaStream_t const stream);
INSTANTIATE_NOAUX_TC
(
float
,
int32_t
);
INSTANTIATE_NOAUX_TC
(
half
,
int32_t
);
INSTANTIATE_NOAUX_TC
(
__nv_bfloat16
,
int32_t
);
INSTANTIATE_NOAUX_TC
(
float
,
float
,
int32_t
);
INSTANTIATE_NOAUX_TC
(
float
,
half
,
int32_t
);
INSTANTIATE_NOAUX_TC
(
float
,
__nv_bfloat16
,
int32_t
);
INSTANTIATE_NOAUX_TC
(
half
,
float
,
int32_t
);
INSTANTIATE_NOAUX_TC
(
half
,
half
,
int32_t
);
INSTANTIATE_NOAUX_TC
(
half
,
__nv_bfloat16
,
int32_t
);
INSTANTIATE_NOAUX_TC
(
__nv_bfloat16
,
float
,
int32_t
);
INSTANTIATE_NOAUX_TC
(
__nv_bfloat16
,
half
,
int32_t
);
INSTANTIATE_NOAUX_TC
(
__nv_bfloat16
,
__nv_bfloat16
,
int32_t
);
}
// end namespace moe
}
// namespace vllm
...
...
@@ -824,6 +838,7 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
int64_t
topk
,
bool
renormalize
,
double
routed_scaling_factor
,
torch
::
Tensor
const
&
bias
,
int64_t
scoring_func
=
0
)
{
auto
data_type
=
scores
.
scalar_type
();
auto
bias_type
=
bias
.
scalar_type
();
auto
input_size
=
scores
.
sizes
();
int64_t
num_tokens
=
input_size
[
0
];
int64_t
num_experts
=
input_size
[
1
];
...
...
@@ -847,39 +862,62 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
(
scores
.
get_device
());
#define LAUNCH_KERNEL(T, IdxT) \
do { \
switch (bias_type) { \
case torch::kFloat16: \
vllm::moe::invokeNoAuxTc<T, half, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<T*>(group_scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<half const*>(bias.data_ptr()), num_tokens, \
num_experts, n_group, topk_group, topk, renormalize, \
routed_scaling_factor, static_cast<int>(scoring_func), false, \
stream); \
break; \
case torch::kFloat32: \
vllm::moe::invokeNoAuxTc<T, float, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<T*>(group_scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<float const*>(bias.data_ptr()), num_tokens, \
num_experts, n_group, topk_group, topk, renormalize, \
routed_scaling_factor, static_cast<int>(scoring_func), false, \
stream); \
break; \
case torch::kBFloat16: \
vllm::moe::invokeNoAuxTc<T, __nv_bfloat16, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<T*>(group_scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), \
num_tokens, num_experts, n_group, topk_group, topk, renormalize, \
routed_scaling_factor, static_cast<int>(scoring_func), false, \
stream); \
break; \
default: \
throw std::invalid_argument( \
"Invalid bias dtype, only supports float16, float32, and " \
"bfloat16"); \
break; \
} \
} while (0)
switch
(
data_type
)
{
case
torch
::
kFloat16
:
// Handle Float16
vllm
::
moe
::
invokeNoAuxTc
<
half
,
int32_t
>
(
reinterpret_cast
<
half
*>
(
scores
.
mutable_data_ptr
()),
reinterpret_cast
<
half
*>
(
group_scores
.
mutable_data_ptr
()),
reinterpret_cast
<
float
*>
(
topk_values
.
mutable_data_ptr
()),
reinterpret_cast
<
int32_t
*>
(
topk_indices
.
mutable_data_ptr
()),
reinterpret_cast
<
half
const
*>
(
bias
.
data_ptr
()),
num_tokens
,
num_experts
,
n_group
,
topk_group
,
topk
,
renormalize
,
routed_scaling_factor
,
static_cast
<
int
>
(
scoring_func
),
false
,
stream
);
LAUNCH_KERNEL
(
half
,
int32_t
);
break
;
case
torch
::
kFloat32
:
// Handle Float32
vllm
::
moe
::
invokeNoAuxTc
<
float
,
int32_t
>
(
reinterpret_cast
<
float
*>
(
scores
.
mutable_data_ptr
()),
reinterpret_cast
<
float
*>
(
group_scores
.
mutable_data_ptr
()),
reinterpret_cast
<
float
*>
(
topk_values
.
mutable_data_ptr
()),
reinterpret_cast
<
int32_t
*>
(
topk_indices
.
mutable_data_ptr
()),
reinterpret_cast
<
float
const
*>
(
bias
.
data_ptr
()),
num_tokens
,
num_experts
,
n_group
,
topk_group
,
topk
,
renormalize
,
routed_scaling_factor
,
static_cast
<
int
>
(
scoring_func
),
false
,
stream
);
LAUNCH_KERNEL
(
float
,
int32_t
);
break
;
case
torch
::
kBFloat16
:
// Handle BFloat16
vllm
::
moe
::
invokeNoAuxTc
<
__nv_bfloat16
,
int32_t
>
(
reinterpret_cast
<
__nv_bfloat16
*>
(
scores
.
mutable_data_ptr
()),
reinterpret_cast
<
__nv_bfloat16
*>
(
group_scores
.
mutable_data_ptr
()),
reinterpret_cast
<
float
*>
(
topk_values
.
mutable_data_ptr
()),
reinterpret_cast
<
int32_t
*>
(
topk_indices
.
mutable_data_ptr
()),
reinterpret_cast
<
__nv_bfloat16
const
*>
(
bias
.
data_ptr
()),
num_tokens
,
num_experts
,
n_group
,
topk_group
,
topk
,
renormalize
,
routed_scaling_factor
,
static_cast
<
int
>
(
scoring_func
),
false
,
stream
);
LAUNCH_KERNEL
(
__nv_bfloat16
,
int32_t
);
break
;
default:
// Handle other data types
...
...
@@ -887,5 +925,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
"Invalid dtype, only supports float16, float32, and bfloat16"
);
break
;
}
#undef LAUNCH_KERNEL
return
{
topk_values
,
topk_indices
};
}
csrc/moe/marlin_moe_wna16/.gitignore
View file @
7e63ef82
sm*_kernel_*.cu
kernel_selector.h
kernel_*.cu
csrc/moe/marlin_moe_wna16/generate_kernels.py
View file @
7e63ef82
...
...
@@ -10,6 +10,8 @@ import jinja2
ARCHS
=
[]
SUPPORT_FP8
=
False
SUPPORT_SM75
=
False
SUPPORT_SM80
=
False
for
arch
in
sys
.
argv
[
1
].
split
(
","
):
arch
=
arch
[:
arch
.
index
(
"."
)
+
2
].
replace
(
"."
,
""
)
arch
=
int
(
arch
)
...
...
@@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","):
# with FP16 MMA, so it cannot achieve any acceleration.
if
arch
in
[
89
,
120
]:
SUPPORT_FP8
=
True
if
arch
>=
80
:
SUPPORT_SM80
=
True
if
arch
==
75
:
SUPPORT_SM75
=
True
FILE_HEAD_COMMENT
=
"""
// auto generated by generate_kernels.py
...
...
@@ -157,6 +163,7 @@ def remove_old_kernels():
def
generate_new_kernels
():
result_dict
=
{}
sm_75_result_dict
=
{}
for
quant_config
in
QUANT_CONFIGS
:
c_types
=
quant_config
.
get
(
"c_type"
,
[
"kFloat16"
,
"kBFloat16"
])
...
...
@@ -174,6 +181,8 @@ def generate_new_kernels():
s_type
=
quant_config
.
get
(
"s_type"
,
c_type
)
if
(
a_type
,
b_type
,
c_type
)
not
in
result_dict
:
result_dict
[(
a_type
,
b_type
,
c_type
)]
=
[]
if
a_type
in
[
"kFloat16"
,
"kS8"
]
and
c_type
==
"kFloat16"
:
sm_75_result_dict
[(
a_type
,
b_type
,
c_type
)]
=
[]
for
group_blocks
,
m_blocks
,
thread_configs
in
itertools
.
product
(
all_group_blocks
,
all_m_blocks
,
all_thread_configs
...
...
@@ -197,17 +206,25 @@ def generate_new_kernels():
"thread_k_blocks"
:
thread_k
//
16
,
"thread_n_blocks"
:
thread_n
//
16
,
"m_block_size_8"
:
"true"
if
m_blocks
==
0.5
else
"false"
,
"stages"
:
"pipe_stages"
,
"stages"
:
4
,
"group_blocks"
:
group_blocks
,
"is_zp_float"
:
"false"
,
}
if
SUPPORT_SM80
:
result_dict
[(
a_type
,
b_type
,
c_type
)].
append
(
config
)
if
(
a_type
,
b_type
,
c_type
)
in
sm_75_result_dict
and
SUPPORT_SM75
:
config_sm75
=
config
.
copy
()
config_sm75
[
"stages"
]
=
2
sm_75_result_dict
[(
a_type
,
b_type
,
c_type
)].
append
(
config_sm75
)
kernel_selector_str
=
FILE_HEAD_COMMENT
for
(
a_type
,
b_type
,
c_type
),
config_list
in
result_dict
.
items
():
for
result_dict_tmp
in
[
result_dict
,
sm_75_result_dict
]:
for
(
a_type
,
b_type
,
c_type
),
config_list
in
result_dict_tmp
.
items
():
all_template_str_list
=
[]
if
not
config_list
:
continue
for
config
in
config_list
:
s_type
=
config
[
"s_type"
]
template_str
=
jinja2
.
Template
(
TEMPLATE
).
render
(
...
...
@@ -229,6 +246,7 @@ def generate_new_kernels():
f
"thread_n_blocks ==
{
config
[
'thread_n_blocks'
]
}
"
,
f
"thread_k_blocks ==
{
config
[
'thread_k_blocks'
]
}
"
,
f
"m_block_size_8 ==
{
config
[
'm_block_size_8'
]
}
"
,
f
"stages ==
{
config
[
'stages'
]
}
"
,
f
"group_blocks ==
{
config
[
'group_blocks'
]
}
"
,
f
"is_zp_float ==
{
config
[
'is_zp_float'
]
}
"
,
]
...
...
@@ -262,6 +280,8 @@ def generate_new_kernels():
file_content
+=
"
\n\n
"
.
join
(
all_template_str_list
)
+
"
\n\n
}
\n
"
if
a_type
==
"kFE4M3fn"
:
filename
=
f
"sm89_kernel_
{
a_type
[
1
:]
}
_
{
b_type
[
1
:]
}
_
{
c_type
[
1
:]
}
.cu"
elif
result_dict_tmp
is
sm_75_result_dict
:
filename
=
f
"sm75_kernel_
{
a_type
[
1
:]
}
_
{
b_type
[
1
:]
}
_
{
c_type
[
1
:]
}
.cu"
else
:
filename
=
f
"sm80_kernel_
{
a_type
[
1
:]
}
_
{
b_type
[
1
:]
}
_
{
c_type
[
1
:]
}
.cu"
...
...
csrc/moe/marlin_moe_wna16/kernel.h
View file @
7e63ef82
...
...
@@ -19,8 +19,8 @@
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights,
bool is_ep,
int num_groups, int prob_m,
\
int
prob_n, int
prob_k, int *locks, bool has_bias, bool use_atomic_add, \
bool mul_topk_weights, int num_groups, int prob_m,
int prob_n,
\
int prob_k, int *locks, bool has_bias, bool use_atomic_add,
\
bool use_fp32_reduce
namespace
MARLIN_NAMESPACE_NAME
{
...
...
csrc/moe/marlin_moe_wna16/marlin_template.h
View file @
7e63ef82
...
...
@@ -26,6 +26,7 @@
#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "quantization/gptq_marlin/dequant.h"
#include "quantization/gptq_marlin/marlin_mma.h"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
...
...
@@ -35,7 +36,7 @@
namespace
MARLIN_NAMESPACE_NAME
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ <
80
0
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ <
75
0
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
const
vllm
::
ScalarTypeId
b_type_id
,
// weight MarlinScalarType id
...
...
@@ -70,7 +71,6 @@ __global__ void Marlin(
const
float
*
__restrict__
topk_weights_ptr
,
// moe top weights
int
top_k
,
// num of experts per token
bool
mul_topk_weights
,
// mul topk weights or not
bool
is_ep
,
// expert parallelism
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
...
...
@@ -84,146 +84,6 @@ __global__ void Marlin(
#else
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
template
<
vllm
::
ScalarTypeId
type_id
,
int
k_size
=
16
>
__device__
inline
void
mma
(
const
typename
MarlinScalarType
<
type_id
>::
FragA
&
a_frag
,
const
typename
MarlinScalarType
<
type_id
>::
FragB
&
frag_b
,
typename
MarlinScalarType
<
type_id
>::
FragC
&
frag_c
,
int
idx
=
0
)
{
const
uint32_t
*
a
=
reinterpret_cast
<
const
uint32_t
*>
(
&
a_frag
);
const
uint32_t
*
b
=
reinterpret_cast
<
const
uint32_t
*>
(
&
frag_b
);
using
scalar_t
=
typename
MarlinScalarType
<
type_id
>::
scalar_t
;
if
constexpr
(
k_size
==
16
)
{
if
constexpr
(
std
::
is_same
<
scalar_t
,
half
>::
value
)
{
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
)
{
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
__nv_fp8_e4m3
>::
value
)
{
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a
[
idx
*
2
]),
"r"
(
a
[
idx
*
2
+
1
]),
"r"
(
b
[
idx
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
int8_t
>::
value
)
{
int32_t
*
c
=
reinterpret_cast
<
int32_t
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};
\n
"
:
"=r"
(
c
[
0
]),
"=r"
(
c
[
1
]),
"=r"
(
c
[
2
]),
"=r"
(
c
[
3
])
:
"r"
(
a
[
idx
*
2
]),
"r"
(
a
[
idx
*
2
+
1
]),
"r"
(
b
[
idx
]),
"r"
(
c
[
0
]),
"r"
(
c
[
1
]),
"r"
(
c
[
2
]),
"r"
(
c
[
3
]));
}
}
else
if
(
k_size
==
32
)
{
if
constexpr
(
std
::
is_same
<
scalar_t
,
__nv_fp8_e4m3
>::
value
)
{
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
int8_t
>::
value
)
{
int32_t
*
c
=
reinterpret_cast
<
int32_t
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=r"
(
c
[
0
]),
"=r"
(
c
[
1
]),
"=r"
(
c
[
2
]),
"=r"
(
c
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]),
"r"
(
c
[
0
]),
"r"
(
c
[
1
]),
"r"
(
c
[
2
]),
"r"
(
c
[
3
]));
}
}
}
template
<
vllm
::
ScalarTypeId
type_id
,
int
k_size
=
16
>
__device__
inline
void
mma_trans
(
const
typename
MarlinScalarType
<
type_id
>::
FragA
&
a_frag
,
const
typename
MarlinScalarType
<
type_id
>::
FragB
&
frag_b
,
const
typename
MarlinScalarType
<
type_id
>::
FragB
&
frag_b2
,
typename
MarlinScalarType
<
type_id
>::
FragC
&
frag_c
)
{
const
uint32_t
*
a
=
reinterpret_cast
<
const
uint32_t
*>
(
&
a_frag
);
const
uint32_t
*
b
=
reinterpret_cast
<
const
uint32_t
*>
(
&
frag_b
);
const
uint32_t
*
b2
=
reinterpret_cast
<
const
uint32_t
*>
(
&
frag_b2
);
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
using
scalar_t
=
typename
MarlinScalarType
<
type_id
>::
scalar_t
;
if
constexpr
(
k_size
==
16
)
{
if
constexpr
(
std
::
is_same
<
scalar_t
,
half
>::
value
)
{
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
b
[
0
]),
"r"
(
b2
[
0
]),
"r"
(
b
[
1
]),
"r"
(
b2
[
1
]),
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
)
{
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
b
[
0
]),
"r"
(
b2
[
0
]),
"r"
(
b
[
1
]),
"r"
(
b2
[
1
]),
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
__nv_fp8_e4m3
>::
value
)
{
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
b
[
0
]),
"r"
(
b2
[
0
]),
"r"
(
a
[
0
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
int8_t
>::
value
)
{
int32_t
*
c
=
reinterpret_cast
<
int32_t
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};
\n
"
:
"=r"
(
c
[
0
]),
"=r"
(
c
[
1
]),
"=r"
(
c
[
2
]),
"=r"
(
c
[
3
])
:
"r"
(
b
[
0
]),
"r"
(
b2
[
0
]),
"r"
(
a
[
0
]),
"r"
(
c
[
0
]),
"r"
(
c
[
1
]),
"r"
(
c
[
2
]),
"r"
(
c
[
3
]));
}
}
else
{
if
constexpr
(
std
::
is_same
<
scalar_t
,
__nv_fp8_e4m3
>::
value
)
{
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1200
asm
volatile
(
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
b
[
0
]),
"r"
(
b2
[
0
]),
"r"
(
b
[
1
]),
"r"
(
b2
[
1
]),
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
#else
asm
volatile
(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
b
[
0
]),
"r"
(
b2
[
0
]),
"r"
(
b
[
1
]),
"r"
(
b2
[
1
]),
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
#endif
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
int8_t
>::
value
)
{
int32_t
*
c
=
reinterpret_cast
<
int32_t
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=r"
(
c
[
0
]),
"=r"
(
c
[
1
]),
"=r"
(
c
[
2
]),
"=r"
(
c
[
3
])
:
"r"
(
b
[
0
]),
"r"
(
b2
[
0
]),
"r"
(
b
[
1
]),
"r"
(
b2
[
1
]),
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
c
[
0
]),
"r"
(
c
[
1
]),
"r"
(
c
[
2
]),
"r"
(
c
[
3
]));
}
}
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
template
<
int
count
,
vllm
::
ScalarTypeId
type_id
>
...
...
@@ -412,7 +272,6 @@ __global__ void Marlin(
const
float
*
__restrict__
topk_weights_ptr
,
// moe top weights
int
top_k
,
// num of experts per token
bool
mul_topk_weights
,
// mul topk weights or not
bool
is_ep
,
// expert parallelism
int
num_groups
,
// number of scale groups per output channel
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
...
...
@@ -439,9 +298,20 @@ __global__ void Marlin(
if
constexpr
(
a_type_id
==
vllm
::
kFE4M3fn
.
id
())
return
;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
// Turing TensorCore only supports fp16 and int8
if
constexpr
(
a_type_id
!=
vllm
::
kFloat16
.
id
()
&&
a_type_id
!=
vllm
::
kS8
.
id
())
return
;
#endif
int
num_tokens_past_padded
=
num_tokens_past_padded_ptr
[
0
];
constexpr
int
moe_block_size
=
m_block_size_8
?
8
:
(
16
*
thread_m_blocks
);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
constexpr
bool
use_fp16_accum
=
a_type_id
==
vllm
::
kFloat16
.
id
();
#else
constexpr
bool
use_fp16_accum
=
false
;
#endif
using
Adtype
=
MarlinScalarType
<
a_type_id
>
;
using
Cdtype
=
MarlinScalarType
<
c_type_id
>
;
...
...
@@ -504,14 +374,6 @@ __global__ void Marlin(
// parallel: num valid moe blocks
int
parallel
=
num_tokens_past_padded
/
moe_block_size
;
int
num_valid_blocks
=
parallel
;
if
(
is_ep
)
{
for
(
int
i
=
0
;
i
<
parallel
;
i
++
)
{
if
(
expert_ids_ptr
[
i
]
==
-
1
)
num_valid_blocks
--
;
}
}
int
num_invalid_blocks
=
parallel
-
num_valid_blocks
;
parallel
=
num_valid_blocks
;
int
k_tiles
=
prob_k
/
16
/
thread_k_blocks
;
int
n_tiles
=
prob_n
/
16
/
thread_n_blocks
;
...
...
@@ -618,7 +480,22 @@ __global__ void Marlin(
}
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
if
constexpr
(
moe_block_size
>=
16
)
local_count
+=
__shfl_down_sync
(
0xFFFFFFFF
,
local_count
,
16
);
if
constexpr
(
moe_block_size
>=
8
)
local_count
+=
__shfl_down_sync
(
0xFFFFFFFF
,
local_count
,
8
);
if
constexpr
(
moe_block_size
>=
4
)
local_count
+=
__shfl_down_sync
(
0xFFFFFFFF
,
local_count
,
4
);
if
constexpr
(
moe_block_size
>=
2
)
local_count
+=
__shfl_down_sync
(
0xFFFFFFFF
,
local_count
,
2
);
local_count
+=
__shfl_down_sync
(
0xFFFFFFFF
,
local_count
,
1
);
block_num_valid_tokens
=
local_count
;
#else
block_num_valid_tokens
=
__reduce_add_sync
(
0xffffffff
,
local_count
);
#endif
if
(
lane_id
==
0
)
reinterpret_cast
<
int
*>
(
sh_new
)[
0
]
=
block_num_valid_tokens
;
...
...
@@ -651,22 +528,8 @@ __global__ void Marlin(
if
(
par_id
>=
parallel
)
return
;
old_expert_id
=
expert_id
;
if
(
num_invalid_blocks
>
0
)
{
int
skip_count
=
par_id
;
for
(
int
i
=
0
;
i
<
num_tokens_past_padded
/
moe_block_size
;
i
++
)
{
expert_id
=
expert_ids_ptr
[
i
];
if
(
expert_id
!=
-
1
)
{
if
(
skip_count
==
0
)
{
block_id
=
i
;
break
;
};
skip_count
--
;
};
}
}
else
{
block_id
=
par_id
;
expert_id
=
expert_ids_ptr
[
block_id
];
}
if
constexpr
(
b_type
==
vllm
::
kFE2M1f
&&
s_type
==
vllm
::
kFE4M3fn
)
{
uint16_t
val
=
global_scale_ptr
[
expert_id
];
...
...
@@ -1018,10 +881,6 @@ __global__ void Marlin(
constexpr
int
sh_s_size
=
has_act_order
?
(
act_s_max_num_groups
*
s_sh_stride
)
:
(
stages
*
s_sh_stage
);
int4
*
sh_s
=
sh_zp
+
(
stages
*
zp_sh_stage
);
// shared memory reused by reduction should be smaller than
// shared memory used by weight.
static_assert
(
thread_m_blocks
*
16
*
thread_n_blocks
*
16
/
8
<=
stages
*
b_sh_stage
);
int4
*
sh_a
=
sh_s
+
sh_s_size
;
// Register storage for double buffer of shared memory reads.
...
...
@@ -1545,11 +1404,13 @@ __global__ void Marlin(
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
if
constexpr
(
m_block_size_8
)
{
mma_trans
<
a_type_id
>
(
frag_a
[
k2
][
i
],
frag_b0
,
frag_b1
,
mma_trans
<
a_type_id
,
use_fp16_accum
>
(
frag_a
[
k2
][
i
],
frag_b0
,
frag_b1
,
frag_c
[
i
][
j
][
0
]);
}
else
{
mma
<
a_type_id
>
(
frag_a
[
k2
][
i
],
frag_b0
,
frag_c
[
i
][
j
][
0
]);
mma
<
a_type_id
>
(
frag_a
[
k2
][
i
],
frag_b1
,
frag_c
[
i
][
j
][
1
]);
mma
<
a_type_id
,
use_fp16_accum
>
(
frag_a
[
k2
][
i
],
frag_b0
,
frag_c
[
i
][
j
][
0
]);
mma
<
a_type_id
,
use_fp16_accum
>
(
frag_a
[
k2
][
i
],
frag_b1
,
frag_c
[
i
][
j
][
1
]);
}
}
}
...
...
@@ -1583,9 +1444,11 @@ __global__ void Marlin(
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
mma
<
a_type_id
,
32
>
(
frag_a
[
k2
][
i
],
frag_b
[
0
],
mma
<
a_type_id
,
false
,
32
>
(
frag_a
[
k2
][
i
],
frag_b
[
0
],
(
group_blocks
==
-
1
?
frag_c
:
frag_c_tmp
)[
i
][
j
][
0
]);
mma
<
a_type_id
,
32
>
(
frag_a
[
k2
][
i
],
frag_b
[
1
],
mma
<
a_type_id
,
false
,
32
>
(
frag_a
[
k2
][
i
],
frag_b
[
1
],
(
group_blocks
==
-
1
?
frag_c
:
frag_c_tmp
)[
i
][
j
][
1
]);
}
...
...
@@ -2132,6 +1995,21 @@ __global__ void Marlin(
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation.
if
(
slice_iters
==
0
)
{
// convert fp16 accum to fp32 for reduction
if
constexpr
(
use_fp16_accum
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
(
thread_m_blocks
*
(
is_a_8bit
?
2
:
4
)
*
2
);
i
++
)
{
float
*
frag_c_part_float
=
reinterpret_cast
<
float
*>
(
frag_c
)
+
i
*
4
;
scalar_t
*
frag_c_part_half
=
reinterpret_cast
<
scalar_t
*>
(
frag_c_part_float
);
#pragma unroll
for
(
int
i
=
3
;
i
>=
0
;
i
--
)
{
frag_c_part_float
[
i
]
=
Cdtype
::
num2float
(
frag_c_part_half
[
i
]);
}
}
}
if
constexpr
(
is_a_8bit
)
{
float
frag_a_s
[
2
*
thread_m_blocks
];
...
...
csrc/moe/marlin_moe_wna16/ops.cu
View file @
7e63ef82
...
...
@@ -142,7 +142,7 @@ typedef struct {
int
get_scales_cache_size
(
thread_config_t
const
&
th_config
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
)
{
bool
has_act_order
,
bool
is_k_full
,
int
stages
)
{
bool
cache_scales_chunk
=
has_act_order
&&
!
is_k_full
;
int
tb_n
=
th_config
.
thread_n
;
...
...
@@ -160,13 +160,13 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
if
(
cache_scales_chunk
)
{
int
load_groups
=
tb_groups
*
pipe_
stages
*
2
;
// Chunk size is 2x pipeline over dim K
tb_groups
*
stages
*
2
;
// Chunk size is 2x pipeline over dim K
load_groups
=
max
(
load_groups
,
32
);
// We load at least 32 scale groups
return
load_groups
*
tb_n
*
2
;
}
else
{
int
tb_scales
=
tb_groups
*
tb_n
*
2
;
return
tb_scales
*
pipe_
stages
;
return
tb_scales
*
stages
;
}
}
...
...
@@ -174,7 +174,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int
thread_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
has_zp
,
int
is_zp_float
,
bool
is_a_8bit
)
{
int
is_zp_float
,
bool
is_a_8bit
,
int
stages
)
{
int
pack_factor
=
32
/
num_bits
;
// Get B size
...
...
@@ -185,8 +185,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
int
sh_block_meta_size
=
tb_m
*
16
;
int
sh_a_size
=
pipe_
stages
*
(
tb_m
*
tb_k
)
*
(
is_a_8bit
?
1
:
2
);
int
sh_b_size
=
pipe_
stages
*
(
tb_k
*
tb_n
/
pack_factor
)
*
4
;
int
sh_a_size
=
stages
*
(
tb_m
*
tb_k
)
*
(
is_a_8bit
?
1
:
2
);
int
sh_b_size
=
stages
*
(
tb_k
*
tb_n
/
pack_factor
)
*
4
;
int
sh_red_size
=
tb_m
*
(
tb_n
+
8
)
*
2
;
int
sh_bias_size
=
tb_n
*
2
;
int
tmp_size
=
...
...
@@ -195,8 +195,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int
sh_s_size
=
get_scales_cache_size
(
th_config
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
);
int
sh_g_idx_size
=
has_act_order
&&
!
is_k_full
?
pipe_
stages
*
tb_k
/
4
:
0
;
group_size
,
has_act_order
,
is_k_full
,
stages
);
int
sh_g_idx_size
=
has_act_order
&&
!
is_k_full
?
stages
*
tb_k
/
4
:
0
;
int
sh_zp_size
=
0
;
if
(
has_zp
)
{
if
(
is_zp_float
)
...
...
@@ -217,7 +217,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int
thread_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
has_zp
,
int
is_zp_float
,
int
max_shared_mem
,
bool
is_a_8bit
)
{
bool
is_a_8bit
,
int
stages
,
int
max_shared_mem
)
{
// Sanity
if
(
th_config
.
thread_k
==
-
1
||
th_config
.
thread_n
==
-
1
||
th_config
.
num_threads
==
-
1
)
{
...
...
@@ -243,7 +243,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int
cache_size
=
get_kernel_cache_size
(
th_config
,
m_block_size_8
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
is_a_8bit
);
is_k_full
,
has_zp
,
is_zp_float
,
is_a_8bit
,
stages
);
return
cache_size
<=
max_shared_mem
;
}
...
...
@@ -252,7 +252,7 @@ MarlinFuncPtr get_marlin_kernel(
const
vllm
::
ScalarType
c_type
,
const
vllm
::
ScalarType
s_type
,
int
thread_m_blocks
,
int
thread_n_blocks
,
int
thread_k_blocks
,
bool
m_block_size_8
,
bool
has_act_order
,
bool
has_zp
,
int
group_blocks
,
int
threads
,
bool
is_zp_float
)
{
int
threads
,
bool
is_zp_float
,
int
stages
)
{
int
num_bits
=
b_type
.
size_bits
();
auto
kernel
=
MarlinDefault
;
...
...
@@ -266,8 +266,8 @@ exec_config_t determine_exec_config(
const
vllm
::
ScalarType
&
c_type
,
const
vllm
::
ScalarType
&
s_type
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_experts
,
int
top_k
,
int
thread_m_blocks
,
bool
m_block_size_8
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_zp
,
bool
is_zp_float
,
int
max_shared_mem
,
int
s
m
s
,
bool
is_a_8bit
)
{
bool
is_k_full
,
bool
has_zp
,
bool
is_zp_float
,
bool
is_a_8bit
,
int
s
tage
s
,
int
max_shared_mem
,
int
sms
)
{
exec_config_t
exec_cfg
=
exec_config_t
{
1
,
thread_config_t
{
-
1
,
-
1
,
-
1
}};
thread_config_t
*
thread_configs
=
thread_m_blocks
>
1
?
large_batch_thread_configs
...
...
@@ -284,15 +284,15 @@ exec_config_t determine_exec_config(
if
(
!
is_valid_config
(
th_config
,
m_block_size_8
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
max_shared_mem
-
512
,
is_a_8bit
))
{
is_k_full
,
has_zp
,
is_zp_float
,
is_a_8bit
,
stages
,
max_shared_mem
-
512
))
{
continue
;
}
int
cache_size
=
get_kernel_cache_size
(
th_config
,
m_block_size_8
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
is_a_8bit
);
is_a_8bit
,
stages
);
int
group_blocks
=
0
;
if
(
!
has_act_order
)
{
...
...
@@ -303,7 +303,7 @@ exec_config_t determine_exec_config(
get_marlin_kernel
(
a_type
,
b_type
,
c_type
,
s_type
,
thread_m_blocks
,
th_config
.
thread_n
/
16
,
th_config
.
thread_k
/
16
,
m_block_size_8
,
has_act_order
,
has_zp
,
group_blocks
,
th_config
.
num_threads
,
is_zp_float
);
th_config
.
num_threads
,
is_zp_float
,
stages
);
if
(
kernel
==
MarlinDefault
)
continue
;
...
...
@@ -336,14 +336,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void
*
perm
,
void
*
a_tmp
,
void
*
sorted_token_ids
,
void
*
expert_ids
,
void
*
num_tokens_past_padded
,
void
*
topk_weights
,
int
moe_block_size
,
int
num_experts
,
int
top_k
,
bool
mul_topk_weights
,
bool
is_ep
,
int
prob_
m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
vllm
::
ScalarType
const
&
a
_type
,
vllm
::
ScalarType
const
&
b
_type
,
vllm
::
ScalarType
const
&
c
_type
,
vllm
::
ScalarType
const
&
s_type
,
bool
has_bias
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_zp
,
int
num_groups
,
int
group_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
blocks_per_sm
,
bool
use_atomic_add
,
bool
use_fp32_reduce
,
bool
is_zp_float
)
{
int
top_k
,
bool
mul_topk_weights
,
int
prob_m
,
int
prob_
n
,
int
prob_k
,
void
*
workspace
,
vllm
::
ScalarType
const
&
a_type
,
vllm
::
ScalarType
const
&
b
_type
,
vllm
::
ScalarType
const
&
c
_type
,
vllm
::
ScalarType
const
&
s
_type
,
bool
has_bias
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_zp
,
int
num_groups
,
int
group_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
blocks_per_sm
,
bool
use_atomic_add
,
bool
use_fp32_reduce
,
bool
is_zp_float
)
{
int
thread_m_blocks
=
div_ceil
(
moe_block_size
,
16
);
bool
m_block_size_8
=
moe_block_size
==
8
;
bool
is_a_8bit
=
a_type
.
size_bits
()
==
8
;
...
...
@@ -433,8 +433,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
dev
);
cudaDeviceGetAttribute
(
&
minor_capability
,
cudaDevAttrComputeCapabilityMinor
,
dev
);
TORCH_CHECK
(
major_capability
*
10
+
minor_capability
>=
80
,
"marlin kernel only support Ampere or newer GPUs."
);
TORCH_CHECK
(
major_capability
*
10
+
minor_capability
>=
75
,
"marlin kernel only support Turing or newer GPUs."
);
int
stages
=
4
;
if
(
major_capability
==
7
&&
minor_capability
==
5
)
{
stages
=
2
;
TORCH_CHECK
(
a_type
==
vllm
::
kFloat16
||
a_type
==
vllm
::
kS8
,
"Turing only support FP16 or INT8 activation."
);
}
if
(
a_type
==
vllm
::
kFE4M3fn
)
{
TORCH_CHECK
(
major_capability
*
10
+
minor_capability
>=
89
,
"FP8 only support Ada Lovelace or newer GPUs."
);
...
...
@@ -461,8 +467,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
exec_cfg
=
determine_exec_config
(
a_type
,
b_type
,
c_type
,
s_type
,
prob_m
,
prob_n
,
prob_k
,
num_experts
,
top_k
,
thread_m_blocks
,
m_block_size_8
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
max_shared_mem
,
sm
s
,
is_a_8bit
);
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
is_a_8bit
,
stage
s
,
max_shared_mem
,
sms
);
thread_tfg
=
exec_cfg
.
tb_cfg
;
}
...
...
@@ -479,7 +485,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
TORCH_CHECK
(
is_valid_config
(
thread_tfg
,
m_block_size_8
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
max_shared_mem
,
is_a_8bit
),
is_a_8bit
,
stages
,
max_shared_mem
),
"Invalid thread config: thread_m_blocks = "
,
thread_m_blocks
,
", thread_k = "
,
thread_tfg
.
thread_k
,
", thread_n = "
,
thread_tfg
.
thread_n
,
...
...
@@ -493,12 +499,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
int
sh_cache_size
=
get_kernel_cache_size
(
thread_tfg
,
m_block_size_8
,
thread_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
has_zp
,
is_zp_float
,
is_a_8bit
);
is_k_full
,
has_zp
,
is_zp_float
,
is_a_8bit
,
stages
);
auto
kernel
=
get_marlin_kernel
(
a_type
,
b_type
,
c_type
,
s_type
,
thread_m_blocks
,
thread_n_blocks
,
thread_k_blocks
,
m_block_size_8
,
has_act_order
,
has_zp
,
group_blocks
,
num_threads
,
is_zp_float
);
num_threads
,
is_zp_float
,
stages
);
if
(
kernel
==
MarlinDefault
)
{
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
,
prob_m
,
", "
,
prob_n
,
...
...
@@ -517,7 +523,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
kernel
<<<
blocks
,
num_threads
,
max_shared_mem
,
stream
>>>
(
A_ptr
,
B_ptr
,
C_ptr
,
C_tmp_ptr
,
bias_ptr
,
a_s_ptr
,
b_s_ptr
,
g_s_ptr
,
zp_ptr
,
g_idx_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
num_tokens_past_padded_ptr
,
topk_weights_ptr
,
top_k
,
mul_topk_weights
,
is_ep
,
num_groups
,
prob_m
,
topk_weights_ptr
,
top_k
,
mul_topk_weights
,
num_groups
,
prob_m
,
prob_n
,
prob_k
,
locks
,
has_bias
,
use_atomic_add
,
use_fp32_reduce
);
// clang-format on
}
...
...
@@ -535,7 +541,7 @@ torch::Tensor moe_wna16_marlin_gemm(
std
::
optional
<
torch
::
Tensor
>
const
&
perm_or_none
,
torch
::
Tensor
&
workspace
,
torch
::
Tensor
&
sorted_token_ids
,
torch
::
Tensor
&
expert_ids
,
torch
::
Tensor
&
num_tokens_past_padded
,
torch
::
Tensor
&
topk_weights
,
int64_t
moe_block_size
,
int64_t
top_k
,
bool
mul_topk_weights
,
bool
is_ep
,
int64_t
moe_block_size
,
int64_t
top_k
,
bool
mul_topk_weights
,
vllm
::
ScalarTypeId
const
&
b_type_id
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
use_atomic_add
,
bool
use_fp32_reduce
,
bool
is_zp_float
,
int64_t
thread_k
,
int64_t
thread_n
,
...
...
@@ -849,9 +855,9 @@ torch::Tensor moe_wna16_marlin_gemm(
perm
.
data_ptr
(),
a_tmp
.
data_ptr
(),
sorted_token_ids
.
data_ptr
(),
expert_ids
.
data_ptr
(),
num_tokens_past_padded
.
data_ptr
(),
topk_weights
.
data_ptr
(),
moe_block_size
,
num_experts
,
top_k
,
mul_topk_weights
,
is_ep
,
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
a_type
,
b_type
,
c_type
,
s_type
,
has_bias
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
mul_topk_weights
,
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
a_type
,
b_type
,
c_type
,
s_type
,
has_bias
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
blocks_per_sm
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
);
...
...
csrc/moe/torch_bindings.cpp
View file @
7e63ef82
...
...
@@ -80,7 +80,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights,
bool is_ep,
int b_type_id,"
"bool mul_topk_weights, int b_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float,"
...
...
csrc/ops.h
View file @
7e63ef82
...
...
@@ -2,6 +2,7 @@
#include <optional>
#include <torch/library.h>
#include <tuple>
#include "core/scalar_type.hpp"
...
...
@@ -280,6 +281,11 @@ void get_cutlass_moe_mm_problem_sizes(
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
,
std
::
optional
<
bool
>
force_swap_ab
=
std
::
nullopt
);
void
get_cutlass_moe_mm_problem_sizes_from_expert_offsets
(
const
torch
::
Tensor
&
expert_first_token_offset
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
int64_t
n
,
const
int64_t
k
,
const
bool
swap_ab
);
void
get_cutlass_pplx_moe_mm_data
(
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
...
...
@@ -316,6 +322,12 @@ void scaled_fp4_experts_quant(
torch
::
Tensor
const
&
input_offset_by_experts
,
torch
::
Tensor
const
&
output_scale_offset_by_experts
);
void
silu_and_mul_scaled_fp4_experts_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_global_scale
,
torch
::
Tensor
const
&
input_offset_by_experts
,
torch
::
Tensor
const
&
output_scale_offset_by_experts
);
void
per_token_group_quant_fp8
(
const
torch
::
Tensor
&
input
,
torch
::
Tensor
&
output_q
,
torch
::
Tensor
&
output_s
,
int64_t
group_size
,
double
eps
,
double
fp8_min
,
...
...
@@ -350,8 +362,9 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
// void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor const& scale);
// void static_scaled_fp8_quant(
// torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale,
// std::optional<std::tuple<int64_t, int64_t>> group_shape = std::nullopt);
// void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor& scale);
...
...
Prev
1
2
3
4
5
6
7
8
9
…
35
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment