Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a810671a
Commit
a810671a
authored
Jan 08, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.14.0rc0' into v0.14.0rc0-ori
parents
86b5aefe
6a09612b
Changes
291
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1295 additions
and
378 deletions
+1295
-378
csrc/activation_kernels.cu
csrc/activation_kernels.cu
+174
-36
csrc/cpu/cpu_arch_macros.h
csrc/cpu/cpu_arch_macros.h
+5
-5
csrc/cpu/cpu_attn_impl.hpp
csrc/cpu/cpu_attn_impl.hpp
+9
-32
csrc/cpu/cpu_fused_moe.cpp
csrc/cpu/cpu_fused_moe.cpp
+727
-0
csrc/cpu/cpu_types_x86.hpp
csrc/cpu/cpu_types_x86.hpp
+8
-0
csrc/cpu/cpu_wna16.cpp
csrc/cpu/cpu_wna16.cpp
+9
-9
csrc/cpu/dnnl_helper.cpp
csrc/cpu/dnnl_helper.cpp
+6
-6
csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp
csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp
+33
-0
csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp
csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp
+38
-0
csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp
csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp
+19
-0
csrc/cpu/scratchpad_manager.cpp
csrc/cpu/scratchpad_manager.cpp
+0
-23
csrc/cpu/scratchpad_manager.h
csrc/cpu/scratchpad_manager.h
+0
-31
csrc/cpu/torch_bindings.cpp
csrc/cpu/torch_bindings.cpp
+24
-0
csrc/cpu/utils.cpp
csrc/cpu/utils.cpp
+24
-2
csrc/cpu/utils.hpp
csrc/cpu/utils.hpp
+67
-22
csrc/cumem_allocator.cpp
csrc/cumem_allocator.cpp
+10
-0
csrc/moe/grouped_topk_kernels.cu
csrc/moe/grouped_topk_kernels.cu
+10
-3
csrc/moe/marlin_moe_wna16/.gitignore
csrc/moe/marlin_moe_wna16/.gitignore
+1
-0
csrc/moe/marlin_moe_wna16/generate_kernels.py
csrc/moe/marlin_moe_wna16/generate_kernels.py
+76
-56
csrc/moe/marlin_moe_wna16/marlin_template.h
csrc/moe/marlin_moe_wna16/marlin_template.h
+55
-153
No files found.
csrc/activation_kernels.cu
View file @
a810671a
...
...
@@ -15,19 +15,61 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x,
const
scalar_t
&
y
)
{
return
act_first
?
ACT_FN
(
x
)
*
y
:
x
*
ACT_FN
(
y
);
}
// Activation and gating kernel template.
// Returns true when `ptr` sits on a 16-byte boundary, i.e. when it can be
// reinterpreted as an int4 (128-bit) pointer for vectorized access.
__device__ __forceinline__ bool is_16byte_aligned(const void* ptr) {
  // The low four address bits are all zero exactly when the address is a
  // multiple of 16.
  constexpr uintptr_t kLowBitsMask = 16 - 1;
  const uintptr_t address = reinterpret_cast<uintptr_t>(ptr);
  return (address & kLowBitsMask) == 0;
}
// Activation and gating kernel template.
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
),
bool
act_first
>
__global__
void
act_and_mul_kernel
(
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., 2, d]
const
int
d
)
{
constexpr
int
VEC_SIZE
=
16
/
sizeof
(
scalar_t
);
const
int64_t
token_idx
=
blockIdx
.
x
;
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
x
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
idx
]);
const
scalar_t
y
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
d
+
idx
]);
out
[
token_idx
*
d
+
idx
]
=
compute
<
scalar_t
,
ACT_FN
,
act_first
>
(
x
,
y
);
const
scalar_t
*
x_ptr
=
input
+
token_idx
*
2
*
d
;
const
scalar_t
*
y_ptr
=
x_ptr
+
d
;
scalar_t
*
out_ptr
=
out
+
token_idx
*
d
;
// Check alignment for 128-bit vectorized access.
// All three pointers must be 16-byte aligned for safe int4 operations.
const
bool
aligned
=
is_16byte_aligned
(
x_ptr
)
&&
is_16byte_aligned
(
y_ptr
)
&&
is_16byte_aligned
(
out_ptr
);
if
(
aligned
&&
d
>=
VEC_SIZE
)
{
// Fast path: 128-bit vectorized loop
const
int4
*
x_vec
=
reinterpret_cast
<
const
int4
*>
(
x_ptr
);
const
int4
*
y_vec
=
reinterpret_cast
<
const
int4
*>
(
y_ptr
);
int4
*
out_vec
=
reinterpret_cast
<
int4
*>
(
out_ptr
);
const
int
num_vecs
=
d
/
VEC_SIZE
;
const
int
vec_end
=
num_vecs
*
VEC_SIZE
;
for
(
int
i
=
threadIdx
.
x
;
i
<
num_vecs
;
i
+=
blockDim
.
x
)
{
int4
x
=
VLLM_LDG
(
&
x_vec
[
i
]),
y
=
VLLM_LDG
(
&
y_vec
[
i
]),
r
;
auto
*
xp
=
reinterpret_cast
<
scalar_t
*>
(
&
x
);
auto
*
yp
=
reinterpret_cast
<
scalar_t
*>
(
&
y
);
auto
*
rp
=
reinterpret_cast
<
scalar_t
*>
(
&
r
);
#pragma unroll
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
j
++
)
{
rp
[
j
]
=
compute
<
scalar_t
,
ACT_FN
,
act_first
>
(
xp
[
j
],
yp
[
j
]);
}
out_vec
[
i
]
=
r
;
}
// Scalar cleanup for remaining elements
for
(
int
i
=
vec_end
+
threadIdx
.
x
;
i
<
d
;
i
+=
blockDim
.
x
)
{
out_ptr
[
i
]
=
compute
<
scalar_t
,
ACT_FN
,
act_first
>
(
VLLM_LDG
(
&
x_ptr
[
i
]),
VLLM_LDG
(
&
y_ptr
[
i
]));
}
}
else
{
// Scalar fallback for unaligned data or small d
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
x
=
VLLM_LDG
(
&
x_ptr
[
idx
]);
const
scalar_t
y
=
VLLM_LDG
(
&
y_ptr
[
idx
]);
out_ptr
[
idx
]
=
compute
<
scalar_t
,
ACT_FN
,
act_first
>
(
x
,
y
);
}
}
}
...
...
@@ -120,50 +162,115 @@ template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
__global__
void
act_and_mul_kernel_with_param
(
scalar_t
*
__restrict__
out
,
const
scalar_t
*
__restrict__
input
,
const
int
d
,
const
float
param
)
{
constexpr
int
VEC_SIZE
=
16
/
sizeof
(
scalar_t
);
const
int64_t
token_idx
=
blockIdx
.
x
;
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
x
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
idx
]);
const
scalar_t
y
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
d
+
idx
]);
out
[
token_idx
*
d
+
idx
]
=
ACT_FN
(
x
,
param
)
*
y
;
const
scalar_t
*
x_ptr
=
input
+
token_idx
*
2
*
d
;
const
scalar_t
*
y_ptr
=
x_ptr
+
d
;
scalar_t
*
out_ptr
=
out
+
token_idx
*
d
;
// Check alignment for 128-bit vectorized access
const
bool
aligned
=
is_16byte_aligned
(
x_ptr
)
&&
is_16byte_aligned
(
y_ptr
)
&&
is_16byte_aligned
(
out_ptr
);
if
(
aligned
&&
d
>=
VEC_SIZE
)
{
// Fast path: 128-bit vectorized loop
const
int4
*
x_vec
=
reinterpret_cast
<
const
int4
*>
(
x_ptr
);
const
int4
*
y_vec
=
reinterpret_cast
<
const
int4
*>
(
y_ptr
);
int4
*
out_vec
=
reinterpret_cast
<
int4
*>
(
out_ptr
);
const
int
num_vecs
=
d
/
VEC_SIZE
;
const
int
vec_end
=
num_vecs
*
VEC_SIZE
;
for
(
int
i
=
threadIdx
.
x
;
i
<
num_vecs
;
i
+=
blockDim
.
x
)
{
int4
x
=
VLLM_LDG
(
&
x_vec
[
i
]),
y
=
VLLM_LDG
(
&
y_vec
[
i
]),
r
;
auto
*
xp
=
reinterpret_cast
<
scalar_t
*>
(
&
x
);
auto
*
yp
=
reinterpret_cast
<
scalar_t
*>
(
&
y
);
auto
*
rp
=
reinterpret_cast
<
scalar_t
*>
(
&
r
);
#pragma unroll
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
j
++
)
{
rp
[
j
]
=
ACT_FN
(
xp
[
j
],
param
)
*
yp
[
j
];
}
out_vec
[
i
]
=
r
;
}
// Scalar cleanup for remaining elements
for
(
int
i
=
vec_end
+
threadIdx
.
x
;
i
<
d
;
i
+=
blockDim
.
x
)
{
out_ptr
[
i
]
=
ACT_FN
(
VLLM_LDG
(
&
x_ptr
[
i
]),
param
)
*
VLLM_LDG
(
&
y_ptr
[
i
]);
}
}
else
{
// Scalar fallback for unaligned data or small d
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
x
=
VLLM_LDG
(
&
x_ptr
[
idx
]);
const
scalar_t
y
=
VLLM_LDG
(
&
y_ptr
[
idx
]);
out_ptr
[
idx
]
=
ACT_FN
(
x
,
param
)
*
y
;
}
}
}
template
<
typename
T
>
__device__
__forceinline__
T
swigluoai_and_mul
(
const
T
&
gate
,
const
T
&
up
,
float
alpha
,
float
limit
)
{
// clamp gate: min=None, max=limit
const
float
gate_f
=
(
float
)
gate
;
const
float
clamped_gate
=
gate_f
>
limit
?
limit
:
gate_f
;
// clamp up: min=-limit, max=limit
const
float
up_f
=
(
float
)
up
;
const
float
clamped_up
=
up_f
>
limit
?
limit
:
(
up_f
<
-
limit
?
-
limit
:
up_f
);
// glu = gate * sigmoid(gate * alpha)
const
float
sigmoid_val
=
1.0
f
/
(
1.0
f
+
expf
(
-
clamped_gate
*
alpha
));
const
float
glu
=
clamped_gate
*
sigmoid_val
;
// (up + 1) * glu
return
(
T
)((
clamped_up
+
1.0
f
)
*
glu
);
// Clamp gate to (-inf, limit] and up to [-limit, limit]
const
float
g
=
fminf
((
float
)
gate
,
limit
);
const
float
u
=
fmaxf
(
fminf
((
float
)
up
,
limit
),
-
limit
);
// glu = gate * sigmoid(gate * alpha), then return (up + 1) * glu
return
(
T
)((
u
+
1.0
f
)
*
g
/
(
1.0
f
+
expf
(
-
g
*
alpha
)));
}
// Interleaved gate/up: input has [gate0, up0, gate1, up1, ...].
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
,
const
scalar_t
&
,
const
float
,
const
float
)>
__global__
void
swigluoai_and_mul_kernel
(
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., 2
,
d]
const
scalar_t
*
__restrict__
input
,
// [..., 2
*
d]
(interleaved)
const
int
d
,
const
float
alpha
,
const
float
limit
)
{
// For interleaved data: input has 2*d elements per token (gate/up pairs)
// output has d elements per token
constexpr
int
VEC_SIZE
=
16
/
sizeof
(
scalar_t
);
constexpr
int
PAIRS
=
VEC_SIZE
/
2
;
// Number of gate/up pairs per int4 load
const
int64_t
token_idx
=
blockIdx
.
x
;
// TODO: Vectorize loads and stores.
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
// gate = x[..., ::2] (even indices)
const
scalar_t
gate
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
2
*
idx
]);
// up = x[..., 1::2] (odd indices)
const
scalar_t
up
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
2
*
idx
+
1
]);
out
[
token_idx
*
d
+
idx
]
=
ACT_FN
(
gate
,
up
,
alpha
,
limit
);
const
scalar_t
*
in_ptr
=
input
+
token_idx
*
2
*
d
;
scalar_t
*
out_ptr
=
out
+
token_idx
*
d
;
// Check alignment for 128-bit vectorized access on input.
// For output we use int2 (64-bit) which has 8-byte alignment requirement.
const
bool
in_aligned
=
is_16byte_aligned
(
in_ptr
);
const
bool
out_aligned
=
(
reinterpret_cast
<
uintptr_t
>
(
out_ptr
)
&
7
)
==
0
;
// 8-byte for int2
if
(
in_aligned
&&
out_aligned
&&
d
>=
PAIRS
)
{
// Fast path: vectorized loop
// Each int4 load gives VEC_SIZE elements = PAIRS gate/up pairs
// Each int2 store writes PAIRS output elements
const
int4
*
in_vec
=
reinterpret_cast
<
const
int4
*>
(
in_ptr
);
int2
*
out_vec
=
reinterpret_cast
<
int2
*>
(
out_ptr
);
const
int
num_vecs
=
d
/
PAIRS
;
const
int
vec_end
=
num_vecs
*
PAIRS
;
for
(
int
i
=
threadIdx
.
x
;
i
<
num_vecs
;
i
+=
blockDim
.
x
)
{
int4
v
=
VLLM_LDG
(
&
in_vec
[
i
]);
int2
r
;
auto
*
vp
=
reinterpret_cast
<
scalar_t
*>
(
&
v
);
auto
*
rp
=
reinterpret_cast
<
scalar_t
*>
(
&
r
);
#pragma unroll
for
(
int
j
=
0
;
j
<
PAIRS
;
j
++
)
{
rp
[
j
]
=
ACT_FN
(
vp
[
2
*
j
],
vp
[
2
*
j
+
1
],
alpha
,
limit
);
}
out_vec
[
i
]
=
r
;
}
// Scalar cleanup for remaining elements
for
(
int
i
=
vec_end
+
threadIdx
.
x
;
i
<
d
;
i
+=
blockDim
.
x
)
{
out_ptr
[
i
]
=
ACT_FN
(
VLLM_LDG
(
&
in_ptr
[
2
*
i
]),
VLLM_LDG
(
&
in_ptr
[
2
*
i
+
1
]),
alpha
,
limit
);
}
}
else
{
// Scalar fallback for unaligned data or small d
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
// gate = x[..., ::2] (even indices)
const
scalar_t
gate
=
VLLM_LDG
(
&
in_ptr
[
2
*
idx
]);
// up = x[..., 1::2] (odd indices)
const
scalar_t
up
=
VLLM_LDG
(
&
in_ptr
[
2
*
idx
+
1
]);
out_ptr
[
idx
]
=
ACT_FN
(
gate
,
up
,
alpha
,
limit
);
}
}
}
...
...
@@ -217,10 +324,41 @@ __global__ void activation_kernel(
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., d]
const
int
d
)
{
constexpr
int
VEC_SIZE
=
16
/
sizeof
(
scalar_t
);
const
int64_t
token_idx
=
blockIdx
.
x
;
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
x
=
VLLM_LDG
(
&
input
[
token_idx
*
d
+
idx
]);
out
[
token_idx
*
d
+
idx
]
=
ACT_FN
(
x
);
const
scalar_t
*
in_ptr
=
input
+
token_idx
*
d
;
scalar_t
*
out_ptr
=
out
+
token_idx
*
d
;
// Check alignment for 128-bit vectorized access
const
bool
aligned
=
is_16byte_aligned
(
in_ptr
)
&&
is_16byte_aligned
(
out_ptr
);
if
(
aligned
&&
d
>=
VEC_SIZE
)
{
// Fast path: 128-bit vectorized loop
const
int4
*
in_vec
=
reinterpret_cast
<
const
int4
*>
(
in_ptr
);
int4
*
out_vec
=
reinterpret_cast
<
int4
*>
(
out_ptr
);
const
int
num_vecs
=
d
/
VEC_SIZE
;
const
int
vec_end
=
num_vecs
*
VEC_SIZE
;
for
(
int
i
=
threadIdx
.
x
;
i
<
num_vecs
;
i
+=
blockDim
.
x
)
{
int4
v
=
VLLM_LDG
(
&
in_vec
[
i
]),
r
;
auto
*
vp
=
reinterpret_cast
<
scalar_t
*>
(
&
v
);
auto
*
rp
=
reinterpret_cast
<
scalar_t
*>
(
&
r
);
#pragma unroll
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
j
++
)
{
rp
[
j
]
=
ACT_FN
(
vp
[
j
]);
}
out_vec
[
i
]
=
r
;
}
// Scalar cleanup for remaining elements
for
(
int
i
=
vec_end
+
threadIdx
.
x
;
i
<
d
;
i
+=
blockDim
.
x
)
{
out_ptr
[
i
]
=
ACT_FN
(
VLLM_LDG
(
&
in_ptr
[
i
]));
}
}
else
{
// Scalar fallback for unaligned data or small d
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
x
=
VLLM_LDG
(
&
in_ptr
[
idx
]);
out_ptr
[
idx
]
=
ACT_FN
(
x
);
}
}
}
...
...
csrc/cpu/cpu_a
ttn
_macros.h
→
csrc/cpu/cpu_a
rch
_macros.h
View file @
a810671a
#ifndef CPU_A
TTN
_MACROS_H
#define CPU_A
TTN
_MACROS_H
#ifndef CPU_A
RCH
_MACROS_H
#define CPU_A
RCH
_MACROS_H
// x86_64
#ifdef __x86_64__
...
...
@@ -26,7 +26,7 @@
_mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); \
const __m512i vec_127 = _mm512_set1_epi32(0x0000007f); \
const int n_mantissa_bits = 23; \
auto fast_exp = [&](vec_op::FP32Vec16& vec) __attribute__((
\
auto fast_exp = [&](
const
vec_op::FP32Vec16& vec) __attribute__(( \
always_inline)) { \
__m512 values = vec.reg; \
auto less_ln_flt_min_mask = \
...
...
@@ -98,7 +98,7 @@
poly = vbslq_f32(hi_mask, inf, poly); \
return vbslq_f32(lo_mask, zero, poly); \
}; \
auto fast_exp = [&](vec_op::FP32Vec16& vec)
\
auto fast_exp = [&](
const
vec_op::FP32Vec16& vec) \
__attribute__((always_inline)) { \
float32x4x4_t result; \
result.val[0] = neon_expf(vec.reg.val[0]); \
...
...
@@ -110,4 +110,4 @@
#endif // __aarch64__
#endif
\ No newline at end of file
#endif
csrc/cpu/cpu_attn_impl.hpp
View file @
a810671a
...
...
@@ -8,10 +8,8 @@
#include <sys/sysctl.h>
#endif
#include "cpu_types.hpp"
#include "scratchpad_manager.h"
#include "cpu_attn_macros.h"
#include "utils.hpp"
#include "cpu/cpu_arch_macros.h"
#include "cpu/utils.hpp"
namespace
cpu_attention
{
enum
class
ISA
{
AMX
,
VEC
,
VEC16
,
NEON
};
...
...
@@ -378,12 +376,13 @@ class AttentionScheduler {
static
constexpr
int32_t
MaxQTileIterNum
=
128
;
AttentionScheduler
()
:
available_cache_size_
(
get_available_l2_size
())
{}
AttentionScheduler
()
:
available_cache_size_
(
cpu_utils
::
get_available_l2_size
())
{}
torch
::
Tensor
schedule
(
const
ScheduleInput
&
input
)
const
{
const
bool
casual
=
input
.
casual
;
const
int32_t
thread_num
=
omp_get_max_threads
();
const
int64_t
cache_size
=
get_available_l2_size
();
const
int64_t
cache_size
=
cpu_utils
::
get_available_l2_size
();
const
int32_t
max_num_q_per_iter
=
input
.
max_num_q_per_iter
;
const
int32_t
kv_len_alignment
=
input
.
kv_block_alignment
;
int32_t
q_head_per_kv
=
input
.
num_heads_q
/
input
.
num_heads_kv
;
...
...
@@ -659,7 +658,7 @@ class AttentionScheduler {
metadata_ptr
->
thread_num
+
metadata_ptr
->
reduction_scratchpad_size_per_kv_head
*
(
use_gqa
?
input
.
num_heads_kv
:
input
.
num_heads_q
);
DNNL
ScratchPadManager
::
get_
dnnl_
scratchpad_manager
()
->
realloc
(
cpu_utils
::
ScratchPadManager
::
get_scratchpad_manager
()
->
realloc
(
scratchpad_size
);
// metadata_ptr->print();
...
...
@@ -667,7 +666,7 @@ class AttentionScheduler {
// test out of boundary access
// {
// float* cache_ptr =
//
DNNL
ScratchPadManager::get
_dnn
l_scratchpad_manager()->get_data<float>();
//
cpu_utils::
ScratchPadManager::get_scratchpad_manager()->get_data<float>();
// for (int64_t i = 0; i < scratchpad_size / sizeof(float); ++i) {
// cache_ptr[i] = std::numeric_limits<float>::quiet_NaN();
// }
...
...
@@ -749,27 +748,6 @@ class AttentionScheduler {
return
std
::
max
(
rounded_tile_size
,
round_size
);
}
static
int64_t
get_available_l2_size
()
{
static
int64_t
size
=
[]()
{
#if defined(__APPLE__)
// macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname.
int64_t
l2_cache_size
=
0
;
size_t
len
=
sizeof
(
l2_cache_size
);
if
(
sysctlbyname
(
"hw.l2cachesize"
,
&
l2_cache_size
,
&
len
,
NULL
,
0
)
==
0
&&
l2_cache_size
>
0
)
{
return
l2_cache_size
>>
1
;
// use 50% of L2 cache
}
// Fallback if sysctlbyname fails
return
128LL
*
1024
>>
1
;
// use 50% of 128KB
#else
long
l2_cache_size
=
sysconf
(
_SC_LEVEL2_CACHE_SIZE
);
TORCH_CHECK_NE
(
l2_cache_size
,
-
1
);
return
l2_cache_size
>>
1
;
// use 50% of L2 cache
#endif
}();
return
size
;
}
private:
int64_t
available_cache_size_
;
};
...
...
@@ -1402,7 +1380,7 @@ class AttentionMainLoop {
// init buffers
void
*
scratchpad_ptr
=
DNNL
ScratchPadManager
::
get_
dnnl_
scratchpad_manager
()
cpu_utils
::
ScratchPadManager
::
get_scratchpad_manager
()
->
get_data
<
void
>
();
AttentionScratchPad
buffer_manager
(
thread_id
,
metadata
,
scratchpad_ptr
);
...
...
@@ -1422,8 +1400,7 @@ class AttentionMainLoop {
}
}
const
int64_t
available_cache_size
=
AttentionScheduler
::
get_available_l2_size
();
const
int64_t
available_cache_size
=
cpu_utils
::
get_available_l2_size
();
const
int32_t
default_tile_size
=
AttentionScheduler
::
calcu_default_tile_size
(
available_cache_size
,
head_dim
,
sizeof
(
kv_cache_t
),
...
...
csrc/cpu/cpu_fused_moe.cpp
0 → 100644
View file @
a810671a
This diff is collapsed.
Click to expand it.
csrc/cpu/cpu_types_x86.hpp
View file @
a810671a
...
...
@@ -352,6 +352,10 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit
FP32Vec16
(
bool
,
void
*
ptr
)
:
reg
((
__m512
)
_mm512_stream_load_si512
(
ptr
))
{}
// Indexed (gather) load: fills the 16 lanes from `ptr` using the 32-bit
// element indices held in `idx`; the byte scale is the size of one float.
explicit FP32Vec16(const float* ptr, INT32Vec16 idx)
    : reg(_mm512_i32gather_ps(idx.reg, ptr, sizeof(float))) {}
explicit
FP32Vec16
(
__m512
data
)
:
reg
(
data
)
{}
// de-pack 4 bit values
...
...
@@ -408,6 +412,10 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
return
FP32Vec16
(
_mm512_sub_ps
(
reg
,
b
.
reg
));
}
// Unary minus: negates every lane by flipping its sign bit.
FP32Vec16 operator-() const {
  // -0.0f has only the sign bit set, so XOR-ing with it toggles the sign and
  // leaves exponent and mantissa untouched.
  const __m512 sign_bit = _mm512_set1_ps(-0.0f);
  return FP32Vec16(_mm512_xor_ps(reg, sign_bit));
}
FP32Vec16
operator
/
(
const
FP32Vec16
&
b
)
const
{
return
FP32Vec16
(
_mm512_div_ps
(
reg
,
b
.
reg
));
}
...
...
csrc/cpu/cpu_wna16.cpp
View file @
a810671a
#include "cpu_types.hpp"
#include "scratchpad_manager.h"
#include "utils.hpp"
#include "cpu/cpu_types.hpp"
#include "cpu/utils.hpp"
#ifdef CPU_CAPABILITY_AMXBF16
#include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp"
...
...
@@ -158,7 +157,7 @@ void cpu_gemm_wna16_impl(
// a simple schedule policy, just to hold more B tiles in L2 and make sure
// each thread has tasks
const
int32_t
n_partition_size
=
[
&
]()
{
const
int64_t
cache_size
=
cpu_utils
::
get_l2_size
();
const
int64_t
cache_size
=
cpu_utils
::
get_
available_
l2_size
();
int64_t
ps_cache_limit
=
cache_size
/
(
k_size
*
sizeof
(
scalar_t
));
int64_t
ps_thread_limit
=
n_size
/
thread_num
;
ps_cache_limit
=
...
...
@@ -179,8 +178,8 @@ void cpu_gemm_wna16_impl(
const
int64_t
b_buffer_offset
=
0
;
const
int64_t
c_buffer_offset
=
b_buffer_size
;
const
int64_t
buffer_size
=
b_buffer_size
+
c_buffer_size
;
DNNL
ScratchPadManager
::
get_
dnnl_
scratchpad_manager
()
->
realloc
(
buffer_size
*
thread_num
);
cpu_utils
::
ScratchPadManager
::
get_scratchpad_manager
()
->
realloc
(
buffer_size
*
thread_num
);
alignas
(
64
)
cpu_utils
::
Counter
counter
;
cpu_utils
::
Counter
*
counter_ptr
=
&
counter
;
...
...
@@ -190,9 +189,10 @@ void cpu_gemm_wna16_impl(
scalar_t
*
__restrict__
b_buffer
=
nullptr
;
float
*
__restrict__
c_buffer
=
nullptr
;
{
uint8_t
*
buffer_ptr
=
DNNLScratchPadManager
::
get_dnnl_scratchpad_manager
()
->
get_data
<
uint8_t
>
()
+
thread_id
*
buffer_size
;
uint8_t
*
buffer_ptr
=
cpu_utils
::
ScratchPadManager
::
get_scratchpad_manager
()
->
get_data
<
uint8_t
>
()
+
thread_id
*
buffer_size
;
b_buffer
=
reinterpret_cast
<
scalar_t
*>
(
buffer_ptr
+
b_buffer_offset
);
c_buffer
=
reinterpret_cast
<
float
*>
(
buffer_ptr
+
c_buffer_offset
);
}
...
...
csrc/cpu/dnnl_helper.cpp
View file @
a810671a
...
...
@@ -4,8 +4,8 @@
#include "common/memory_desc.hpp"
#include "common/memory.hpp"
#include "
dnnl_helper.h
"
#include "
scratchpad_manag
er.h"
#include "
cpu/utils.hpp
"
#include "
cpu/dnnl_help
er.h"
static
dnnl
::
engine
&
default_engine
()
{
static
dnnl
::
engine
engine
(
dnnl
::
engine
::
kind
::
cpu
,
0
);
...
...
@@ -274,7 +274,7 @@ void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
auto
&&
[
scratchpad_storage
,
scratchpad_mem_desc
]
=
get_runtime_memory_ptr
(
5
);
scratchpad_storage
->
set_data_handle
(
DNNL
ScratchPadManager
::
get_
dnnl_
scratchpad_manager
()
->
get_data
<
void
>
());
cpu_utils
::
ScratchPadManager
::
get_scratchpad_manager
()
->
get_data
<
void
>
());
matmul
.
execute
(
default_stream
(),
memory_cache_
);
default_stream
().
wait
();
...
...
@@ -294,7 +294,7 @@ dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
return
m_size_cache_
->
get_or_create
(
key
,
[
&
]()
{
dnnl
::
matmul
::
primitive_desc
desc
=
this
->
create_primitive_desc
(
key
,
false
);
auto
manager
=
DNNL
ScratchPadManager
::
get_
dnnl_
scratchpad_manager
();
auto
manager
=
cpu_utils
::
ScratchPadManager
::
get_scratchpad_manager
();
manager
->
realloc
(
desc
.
scratchpad_desc
().
get_size
());
return
dnnl
::
matmul
(
desc
);
});
...
...
@@ -470,7 +470,7 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) {
auto
&&
[
scratchpad_storage
,
scratchpad_mem_desc
]
=
get_runtime_memory_ptr
(
3
);
scratchpad_storage
->
set_data_handle
(
DNNL
ScratchPadManager
::
get_
dnnl_
scratchpad_manager
()
->
get_data
<
void
>
());
cpu_utils
::
ScratchPadManager
::
get_scratchpad_manager
()
->
get_data
<
void
>
());
matmul
.
execute
(
default_stream
(),
memory_cache_
);
default_stream
().
wait
();
...
...
@@ -486,7 +486,7 @@ dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache(
}
return
m_size_cache_
->
get_or_create
(
key
,
[
&
]()
{
dnnl
::
matmul
::
primitive_desc
desc
=
this
->
create_primitive_desc
(
key
,
false
);
auto
manager
=
DNNL
ScratchPadManager
::
get_
dnnl_
scratchpad_manager
();
auto
manager
=
cpu_utils
::
ScratchPadManager
::
get_scratchpad_manager
();
manager
->
realloc
(
desc
.
scratchpad_desc
().
get_size
());
return
dnnl
::
matmul
(
desc
);
});
...
...
csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp
View file @
a810671a
...
...
@@ -235,6 +235,39 @@ class MicroGemm<cpu_utils::ISA::AMX, scalar_t> {
}
}
// Repacks a contiguous [output_size, input_size] weight into blocks of 16
// output rows. Elements are moved in 32-bit units (4 / sizeof(scalar_t)
// elements each); inside a block, the 32-bit word `w` of row `r` is stored at
// word offset `w * 16 + r`, so the 16 rows of a block are interleaved word by
// word.
//
// output_size must be a multiple of 16 and input_size a multiple of
// 16 * (4 / sizeof(scalar_t)).
static void pack_weight(const scalar_t* __restrict__ weight,
                        scalar_t* __restrict__ packed_weight,
                        const int32_t output_size, const int32_t input_size) {
  constexpr int32_t elem_num_per_group = 4 / sizeof(scalar_t);
  TORCH_CHECK_EQ(output_size % 16, 0);
  TORCH_CHECK_EQ(input_size % (16 * elem_num_per_group), 0);

  const int32_t block_count = output_size / 16;
  const int32_t words_per_row = input_size / elem_num_per_group;
  // Number of scalar elements covered by one block of 16 rows.
  const int64_t block_elems = static_cast<int64_t>(16) * input_size;

  for (int32_t block = 0; block < block_count; ++block) {
    const int32_t* src_block =
        reinterpret_cast<const int32_t*>(weight + block * block_elems);
    int32_t* dst_block =
        reinterpret_cast<int32_t*>(packed_weight + block * block_elems);

    for (int32_t row = 0; row < 16; ++row) {
      const int32_t* src_row =
          src_block + static_cast<int64_t>(row) * words_per_row;
      // Row `row` owns slot `row` of every 16-word group in the packed block.
      int32_t* dst_slot = dst_block + row;
      for (int32_t word = 0; word < words_per_row; ++word) {
        dst_slot[static_cast<int64_t>(word) * 16] = src_row[word];
      }
    }
  }
}
private:
alignas
(
64
)
__tilecfg
amx_tile_config_
;
int32_t
curr_m_
;
...
...
csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp
View file @
a810671a
...
...
@@ -13,6 +13,9 @@ namespace cpu_micro_gemm {
#define CPU_MICRO_GEMM_PARAMS \
a_ptr, b_ptr, c_ptr, m, k, lda, b_n_group_stride, ldc, accum_c
// Note: weights for MicroGemm should be packed as (output_size / 16)
// contiguous blocks, i.e. the logical shape of each block is [16, input_size].
// The actual layout inside a block can be ISA-specific.
template
<
cpu_utils
::
ISA
isa
,
typename
scalar_t
>
class
MicroGemm
{
public:
...
...
@@ -86,6 +89,41 @@ FORCE_INLINE void bias_epilogue(float* __restrict__ c_ptr,
curr_d
+=
ldd
;
}
}
// Adds a bias row to an m x n_size fp32 tile and writes the result to a
// second buffer: d[i][j] = c[i][j] + float(bias[j]).
//
//   c_ptr / ldc : source tile and its row stride (in floats)
//   d_ptr / ldd : destination tile and its row stride (in floats)
//   bias_ptr    : n_size bias values of type scalar_t
//   m           : number of rows to process
//
// n_size must be a multiple of 16 and at most 256 (16 vectors of 16 lanes).
template <int32_t n_size, typename scalar_t>
FORCE_INLINE void add_bias_epilogue(float* c_ptr, float* d_ptr,
                                    scalar_t* __restrict__ bias_ptr,
                                    const int32_t m, const int64_t ldc,
                                    const int64_t ldd) {
  using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
  static_assert(n_size % 16 == 0);
  constexpr int32_t n_group_num = n_size / 16;
  static_assert(n_group_num <= 16);

  // Convert the bias to fp32 once, 16 lanes at a time, so the row loop below
  // only performs adds.
  vec_op::FP32Vec16 bias_vecs[n_group_num];
  scalar_t* __restrict__ bias_cursor = bias_ptr;
  vec_op::unroll_loop<int32_t, n_group_num>([&](int32_t group) {
    scalar_vec_t raw_bias(bias_cursor);
    bias_vecs[group] = vec_op::FP32Vec16(raw_bias);
    bias_cursor += 16;
  });

  float* c_row = c_ptr;
  float* d_row = d_ptr;
  for (int32_t row = 0; row < m; ++row, c_row += ldc, d_row += ldd) {
    float* c_cursor = c_row;
    float* d_cursor = d_row;
    vec_op::unroll_loop<int32_t, n_group_num>([&](int32_t group) {
      vec_op::FP32Vec16 acc(c_cursor);
      acc = acc + bias_vecs[group];
      acc.save(d_cursor);
      c_cursor += 16;
      d_cursor += 16;
    });
  }
}
}
// namespace cpu_micro_gemm
#endif
csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp
View file @
a810671a
...
...
@@ -109,6 +109,25 @@ class MicroGemm<cpu_utils::ISA::VEC, scalar_t> {
void
gemm
(
DEFINE_CPU_MICRO_GEMM_PARAMS
)
{
TileGemm82
<
scalar_t
>::
gemm
(
CPU_MICRO_GEMM_PARAMS
);
}
// Note: pack contiguous weight [output_size, input_size] as contiguous
// packed weight [output_size / 16, input_size, 16].
//
// Element (o, i) of `weight` is written to
//   packed_weight[(o / 16) * 16 * input_size + i * 16 + (o % 16)].
// output_size must be a multiple of 16.
//
// All element offsets are computed in 64-bit: the products
// `o_idx * input_size` and `(o_idx / 16) * (16 * input_size)` can exceed
// INT32_MAX for large weights, which would be signed overflow if evaluated
// in int32_t.
static void pack_weight(const scalar_t* __restrict__ weight,
                        scalar_t* __restrict__ packed_weight,
                        const int32_t output_size, const int32_t input_size) {
  TORCH_CHECK_EQ(output_size % 16, 0);
  const int64_t row_elems = input_size;
  const int64_t block_elems = 16 * row_elems;
  for (int32_t o_idx = 0; o_idx < output_size; ++o_idx) {
    const scalar_t* __restrict__ curr_weight = weight + o_idx * row_elems;
    // Start of this row's lane inside its 16-row block.
    scalar_t* __restrict__ curr_packed_weight =
        packed_weight + (o_idx / 16) * block_elems + o_idx % 16;
    for (int32_t i_idx = 0; i_idx < input_size; ++i_idx) {
      // Consecutive input elements of one row are 16 elements apart.
      curr_packed_weight[static_cast<int64_t>(i_idx) * 16] =
          curr_weight[i_idx];
    }
  }
}
};
}
// namespace cpu_micro_gemm
...
...
csrc/cpu/scratchpad_manager.cpp
deleted
100644 → 0
View file @
86b5aefe
#include <cstdlib>
#include "scratchpad_manager.h"
// Starts with no buffer and immediately reserves an initial scratchpad of
// 128 allocation units.
DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) {
  constexpr size_t kInitialUnits = 128;
  this->realloc(kInitialUnits * allocation_unit);
}
// Grows the scratchpad so that it holds at least `new_size` bytes (rounded up
// to a whole number of allocation units). The buffer never shrinks, and its
// previous contents are NOT preserved when it grows.
//
// If the allocation fails, the manager is left empty (null pointer, size 0)
// and a message is printed, so that a later call retries instead of treating
// a null buffer as if it had the requested capacity.
void DNNLScratchPadManager::realloc(size_t new_size) {
  new_size = round(new_size);
  if (new_size <= size_) {
    return;
  }

  // std::free accepts nullptr, so no separate null check is needed.
  std::free(ptr_);
  ptr_ = std::aligned_alloc(64, new_size);
  if (ptr_ == nullptr) {
    size_ = 0;
    std::fprintf(stderr,
                 "DNNLScratchPadManager: failed to allocate %zu bytes\n",
                 new_size);
    return;
  }
  size_ = new_size;
}
// Returns the single shared manager, constructed on first use as a
// function-local static.
DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() {
  static DNNLScratchPadManager shared_instance;
  return &shared_instance;
}
csrc/cpu/scratchpad_manager.h
deleted
100644 → 0
View file @
86b5aefe
#ifndef SCRATCHPAD_MANAGER_H
#define SCRATCHPAD_MANAGER_H
#include <cstddef>
#include <cstdio>
// Owns one grow-only scratch buffer that is reached through a shared
// accessor. Sizes are always rounded up to a multiple of `allocation_unit`.
class DNNLScratchPadManager {
 public:
  // Granularity of every allocation request.
  static constexpr size_t allocation_unit = 4 * 1024;  // 4KB

  // Accessor for the shared instance.
  static DNNLScratchPadManager* get_dnnl_scratchpad_manager();

  DNNLScratchPadManager();

  // Returns the start of the buffer viewed as `T*`.
  template <typename T>
  T* get_data() {
    return static_cast<T*>(ptr_);
  }

  // Rounds `size` up to the next multiple of `allocation_unit`.
  static size_t round(size_t size) {
    const size_t unit_count = (size + allocation_unit - 1) / allocation_unit;
    return unit_count * allocation_unit;
  }

  // Ensures the buffer can hold at least `new_size` bytes.
  void realloc(size_t new_size);

 private:
  size_t size_;  // current capacity in bytes
  void* ptr_;    // start of the buffer
};
#endif
csrc/cpu/torch_bindings.cpp
View file @
a810671a
...
...
@@ -110,6 +110,17 @@ void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight,
const
std
::
optional
<
torch
::
Tensor
>&
bias
,
const
int64_t
pack_factor
,
const
std
::
string
&
isa_hint
);
void
prepack_moe_weight
(
const
torch
::
Tensor
&
weight
,
torch
::
Tensor
&
packed_weight
,
const
std
::
string
&
isa
);
void
cpu_fused_moe
(
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
w13
,
const
torch
::
Tensor
&
w2
,
const
std
::
optional
<
torch
::
Tensor
>&
w13_bias
,
const
std
::
optional
<
torch
::
Tensor
>&
w2_bias
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
topk_id
,
const
std
::
string
&
act
,
const
std
::
string
&
isa
);
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
ops
)
{
// vLLM custom ops
...
...
@@ -296,6 +307,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"pack_factor, str isa_hint) -> ()"
);
ops
.
impl
(
"cpu_gemm_wna16"
,
torch
::
kCPU
,
&
cpu_gemm_wna16
);
#endif
// fused moe
#if defined(__AVX512F__)
ops
.
def
(
"prepack_moe_weight(Tensor weight, Tensor(a1!) packed_weight, str isa) "
"-> ()"
);
ops
.
impl
(
"prepack_moe_weight"
,
torch
::
kCPU
,
&
prepack_moe_weight
);
ops
.
def
(
"cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
"Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
"str act, str isa) -> ()"
);
ops
.
impl
(
"cpu_fused_moe"
,
torch
::
kCPU
,
&
cpu_fused_moe
);
#endif
}
TORCH_LIBRARY_EXPAND
(
CONCAT
(
TORCH_EXTENSION_NAME
,
_utils
),
utils
)
{
...
...
csrc/cpu/utils.cpp
View file @
a810671a
...
...
@@ -10,7 +10,7 @@
#define gettid() syscall(SYS_gettid)
#endif
#include "cpu
_type
s.hpp"
#include "cpu
/util
s.hpp"
#ifdef VLLM_NUMA_DISABLED
std
::
string
init_cpu_threads_env
(
const
std
::
string
&
cpu_ids
)
{
...
...
@@ -138,4 +138,26 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
return
ss
.
str
();
}
#endif
#endif // VLLM_NUMA_DISABLED
namespace
cpu_utils
{
ScratchPadManager
::
ScratchPadManager
()
:
size_
(
0
),
ptr_
(
nullptr
)
{
this
->
realloc
(
allocation_unit
*
128
);
}
void
ScratchPadManager
::
realloc
(
size_t
new_size
)
{
new_size
=
round
(
new_size
);
if
(
new_size
>
size_
)
{
if
(
ptr_
!=
nullptr
)
{
std
::
free
(
ptr_
);
}
ptr_
=
std
::
aligned_alloc
(
64
,
new_size
);
size_
=
new_size
;
}
}
ScratchPadManager
*
ScratchPadManager
::
get_scratchpad_manager
()
{
static
ScratchPadManager
manager
;
return
&
manager
;
}
}
// namespace cpu_utils
csrc/cpu/utils.hpp
View file @
a810671a
...
...
@@ -2,19 +2,24 @@
#define UTILS_HPP
#include <atomic>
#include <cassert>
#include <cstdint>
#include <unistd.h>
#include <ATen/cpu/Utils.h>
#if defined(__APPLE__)
#include <sys/sysctl.h>
#endif
#include "cpu_types.hpp"
#include "cpu/cpu_types.hpp"
namespace
cpu_utils
{
enum
class
ISA
{
AMX
,
VEC
};
// Maps an ISA name to its enum value: "amx" -> ISA::AMX, "vec" -> ISA::VEC.
// Any other string fails the TORCH_CHECK with an "Invalid isa type" message.
inline ISA get_isa(const std::string& isa) {
  if (isa == "amx") {
    return ISA::AMX;
  }
  if (isa == "vec") {
    return ISA::VEC;
  }
  TORCH_CHECK(false, "Invalid isa type: " + isa);
}
template
<
typename
T
>
struct
VecTypeTrait
{
using
vec_t
=
void
;
...
...
@@ -48,26 +53,66 @@ struct Counter {
int64_t
acquire_counter
()
{
return
counter
++
;
}
};
inline
int64_t
get_l2_size
()
{
inline
int64_t
get_
available_
l2_size
()
{
static
int64_t
size
=
[]()
{
#if defined(__APPLE__)
// macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname.
int64_t
l2_cache_size
=
0
;
size_t
len
=
sizeof
(
l2_cache_size
);
if
(
sysctlbyname
(
"hw.l2cachesize"
,
&
l2_cache_size
,
&
len
,
NULL
,
0
)
==
0
&&
l2_cache_size
>
0
)
{
return
l2_cache_size
>>
1
;
// use 50% of L2 cache
}
// Fallback if sysctlbyname fails
return
128LL
*
1024
>>
1
;
// use 50% of 128KB
#else
long
l2_cache_size
=
sysconf
(
_SC_LEVEL2_CACHE_SIZE
);
assert
(
l2_cache_size
!=
-
1
);
const
uint32_t
l2_cache_size
=
at
::
cpu
::
L2_cache_size
();
return
l2_cache_size
>>
1
;
// use 50% of L2 cache
#endif
}();
return
size
;
}
// Rounds `size` up to the nearest multiple of the compile-time constant
// `alignment_v`; values that are already multiples are returned unchanged.
template <int32_t alignment_v, typename T>
inline T round_up(T size) {
  const T unit = alignment_v;
  // Bump past the next boundary, then truncate back down onto it.
  const auto bumped = size + unit - 1;
  return (bumped / unit) * unit;
}
// Rounds `size` down to the nearest multiple of the compile-time constant
// `alignment_v`; values that are already multiples are returned unchanged.
template <int32_t alignment_v, typename T>
inline T round_down(T size) {
  const T unit = alignment_v;
  const auto whole_units = size / unit;
  return whole_units * unit;
}
// Debug helper: prints a `row` x `col` matrix to stdout with 5 fixed decimal
// digits, prefixed by `name`. Consecutive rows start `stride` elements apart
// in `ptr`. The text is assembled first and emitted with one printf call.
template <typename T>
inline void print_logits(const char* name, T* ptr, int32_t row, int32_t col,
                         int32_t stride) {
  std::stringstream text;
  text << std::fixed << std::setprecision(5) << name << ": [\n";

  auto* row_start = ptr;
  for (int32_t r = 0; r < row; ++r, row_start += stride) {
    for (int32_t c = 0; c < col; ++c) {
      text << row_start[c] << ", ";
    }
    text << "\n";
  }

  text << "]\n";
  std::printf("%s", text.str().c_str());
}
// Owns one grow-only scratch buffer that is reached through a shared
// accessor. Sizes are always rounded up to a multiple of `allocation_unit`.
class ScratchPadManager {
 public:
  // Granularity of every allocation request.
  static constexpr size_t allocation_unit = 4 * 1024;  // 4KB

  // Accessor for the shared instance.
  static ScratchPadManager* get_scratchpad_manager();

  ScratchPadManager();

  // Returns the start of the buffer viewed as `T*`.
  template <typename T>
  T* get_data() {
    return static_cast<T*>(ptr_);
  }

  // Rounds `size` up to the next multiple of `allocation_unit`.
  static size_t round(size_t size) {
    const size_t unit_count = (size + allocation_unit - 1) / allocation_unit;
    return unit_count * allocation_unit;
  }

  // Ensures the buffer can hold at least `new_size` bytes.
  void realloc(size_t new_size);

 private:
  size_t size_;  // current capacity in bytes
  void* ptr_;    // start of the buffer
};
}
// namespace cpu_utils
#endif
csrc/cumem_allocator.cpp
View file @
a810671a
...
...
@@ -107,6 +107,16 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
prop
.
location
.
id
=
device
;
prop
.
allocFlags
.
compressionType
=
CU_MEM_ALLOCATION_COMP_NONE
;
#ifndef USE_ROCM
int
flag
=
0
;
CUDA_CHECK
(
cuDeviceGetAttribute
(
&
flag
,
CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED
,
device
));
if
(
flag
)
{
// support GPUDirect RDMA if possible
prop
.
allocFlags
.
gpuDirectRDMACapable
=
1
;
}
#endif
#ifndef USE_ROCM
// Allocate memory using cuMemCreate
CUDA_CHECK
(
cuMemCreate
(
p_memHandle
,
size
,
&
prop
,
0
));
...
...
csrc/moe/grouped_topk_kernels.cu
View file @
a810671a
...
...
@@ -446,9 +446,13 @@ __device__ inline T apply_sigmoid(T val) {
template
<
ScoringFunc
SF
,
typename
T
>
__device__
inline
T
apply_scoring
(
T
val
)
{
if
constexpr
(
SF
==
SCORING_SIGMOID
)
{
if
constexpr
(
SF
==
SCORING_NONE
)
{
return
val
;
}
else
if
constexpr
(
SF
==
SCORING_SIGMOID
)
{
return
apply_sigmoid
(
val
);
}
else
{
static_assert
(
SF
==
SCORING_NONE
||
SF
==
SCORING_SIGMOID
,
"Unsupported ScoringFunc in apply_scoring"
);
return
val
;
}
}
...
...
@@ -670,10 +674,13 @@ __global__ void group_idx_and_topk_idx_kernel(
if
(
case_id
<
num_tokens
)
{
if
(
if_proceed_next_topk
)
{
float
scale
=
routed_scaling_factor
;
if
(
renormalize
)
{
scale
/=
topk_sum
;
}
for
(
int
i
=
lane_id
;
i
<
topk
;
i
+=
WARP_SIZE
)
{
float
base
=
cuda_cast
<
float
,
T
>
(
s_topk_value
[
i
]);
float
value
=
renormalize
?
(
base
/
topk_sum
*
routed_scaling_factor
)
:
(
base
*
routed_scaling_factor
);
float
value
=
base
*
scale
;
topk_indices
[
i
]
=
s_topk_idx
[
i
];
topk_values
[
i
]
=
value
;
}
...
...
csrc/moe/marlin_moe_wna16/.gitignore
View file @
a810671a
sm*_kernel_*.cu
kernel_selector.h
kernel_*.cu
csrc/moe/marlin_moe_wna16/generate_kernels.py
View file @
a810671a
...
...
@@ -10,6 +10,8 @@ import jinja2
ARCHS
=
[]
SUPPORT_FP8
=
False
SUPPORT_SM75
=
False
SUPPORT_SM80
=
False
for
arch
in
sys
.
argv
[
1
].
split
(
","
):
arch
=
arch
[:
arch
.
index
(
"."
)
+
2
].
replace
(
"."
,
""
)
arch
=
int
(
arch
)
...
...
@@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","):
# with FP16 MMA, so it cannot achieve any acceleration.
if
arch
in
[
89
,
120
]:
SUPPORT_FP8
=
True
if
arch
>=
80
:
SUPPORT_SM80
=
True
if
arch
==
75
:
SUPPORT_SM75
=
True
FILE_HEAD_COMMENT
=
"""
// auto generated by generate_kernels.py
...
...
@@ -157,6 +163,7 @@ def remove_old_kernels():
def
generate_new_kernels
():
result_dict
=
{}
sm_75_result_dict
=
{}
for
quant_config
in
QUANT_CONFIGS
:
c_types
=
quant_config
.
get
(
"c_type"
,
[
"kFloat16"
,
"kBFloat16"
])
...
...
@@ -174,6 +181,8 @@ def generate_new_kernels():
s_type
=
quant_config
.
get
(
"s_type"
,
c_type
)
if
(
a_type
,
b_type
,
c_type
)
not
in
result_dict
:
result_dict
[(
a_type
,
b_type
,
c_type
)]
=
[]
if
a_type
in
[
"kFloat16"
,
"kS8"
]
and
c_type
==
"kFloat16"
:
sm_75_result_dict
[(
a_type
,
b_type
,
c_type
)]
=
[]
for
group_blocks
,
m_blocks
,
thread_configs
in
itertools
.
product
(
all_group_blocks
,
all_m_blocks
,
all_thread_configs
...
...
@@ -197,78 +206,89 @@ def generate_new_kernels():
"thread_k_blocks"
:
thread_k
//
16
,
"thread_n_blocks"
:
thread_n
//
16
,
"m_block_size_8"
:
"true"
if
m_blocks
==
0.5
else
"false"
,
"stages"
:
"pipe_stages"
,
"stages"
:
4
,
"group_blocks"
:
group_blocks
,
"is_zp_float"
:
"false"
,
}
result_dict
[(
a_type
,
b_type
,
c_type
)].
append
(
config
)
if
SUPPORT_SM80
:
result_dict
[(
a_type
,
b_type
,
c_type
)].
append
(
config
)
if
(
a_type
,
b_type
,
c_type
)
in
sm_75_result_dict
and
SUPPORT_SM75
:
config_sm75
=
config
.
copy
()
config_sm75
[
"stages"
]
=
2
sm_75_result_dict
[(
a_type
,
b_type
,
c_type
)].
append
(
config_sm75
)
kernel_selector_str
=
FILE_HEAD_COMMENT
for
(
a_type
,
b_type
,
c_type
),
config_list
in
result_dict
.
items
():
all_template_str_list
=
[]
for
config
in
config_list
:
s_type
=
config
[
"s_type"
]
template_str
=
jinja2
.
Template
(
TEMPLATE
).
render
(
a_type_id
=
f
"vllm::
{
a_type
}
.id()"
,
b_type_id
=
f
"vllm::
{
b_type
}
.id()"
,
c_type_id
=
f
"vllm::
{
c_type
}
.id()"
,
s_type_id
=
f
"vllm::
{
s_type
}
.id()"
,
**
config
,
)
all_template_str_list
.
append
(
template_str
)
conditions
=
[
f
"a_type == vllm::
{
a_type
}
"
,
f
"b_type == vllm::
{
b_type
}
"
,
f
"c_type == vllm::
{
c_type
}
"
,
f
"s_type == vllm::
{
s_type
}
"
,
f
"threads ==
{
config
[
'threads'
]
}
"
,
f
"thread_m_blocks ==
{
config
[
'thread_m_blocks'
]
}
"
,
f
"thread_n_blocks ==
{
config
[
'thread_n_blocks'
]
}
"
,
f
"thread_k_blocks ==
{
config
[
'thread_k_blocks'
]
}
"
,
f
"m_block_size_8 ==
{
config
[
'm_block_size_8'
]
}
"
,
f
"group_blocks ==
{
config
[
'group_blocks'
]
}
"
,
f
"is_zp_float ==
{
config
[
'is_zp_float'
]
}
"
,
]
conditions
=
" && "
.
join
(
conditions
)
if
kernel_selector_str
==
FILE_HEAD_COMMENT
:
kernel_selector_str
+=
f
"if (
{
conditions
}
)
\n
kernel = "
else
:
kernel_selector_str
+=
f
"else if (
{
conditions
}
)
\n
kernel = "
kernel_template2
=
(
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
kernel_selector_str
+=
(
jinja2
.
Template
(
kernel_template2
).
render
(
for
result_dict_tmp
in
[
result_dict
,
sm_75_result_dict
]:
for
(
a_type
,
b_type
,
c_type
),
config_list
in
result_dict_tmp
.
items
():
all_template_str_list
=
[]
if
not
config_list
:
continue
for
config
in
config_list
:
s_type
=
config
[
"s_type"
]
template_str
=
jinja2
.
Template
(
TEMPLATE
).
render
(
a_type_id
=
f
"vllm::
{
a_type
}
.id()"
,
b_type_id
=
f
"vllm::
{
b_type
}
.id()"
,
c_type_id
=
f
"vllm::
{
c_type
}
.id()"
,
s_type_id
=
f
"vllm::
{
s_type
}
.id()"
,
**
config
,
)
+
"
\n
"
)
all_template_str_list
.
append
(
template_str
)
conditions
=
[
f
"a_type == vllm::
{
a_type
}
"
,
f
"b_type == vllm::
{
b_type
}
"
,
f
"c_type == vllm::
{
c_type
}
"
,
f
"s_type == vllm::
{
s_type
}
"
,
f
"threads ==
{
config
[
'threads'
]
}
"
,
f
"thread_m_blocks ==
{
config
[
'thread_m_blocks'
]
}
"
,
f
"thread_n_blocks ==
{
config
[
'thread_n_blocks'
]
}
"
,
f
"thread_k_blocks ==
{
config
[
'thread_k_blocks'
]
}
"
,
f
"m_block_size_8 ==
{
config
[
'm_block_size_8'
]
}
"
,
f
"stages ==
{
config
[
'stages'
]
}
"
,
f
"group_blocks ==
{
config
[
'group_blocks'
]
}
"
,
f
"is_zp_float ==
{
config
[
'is_zp_float'
]
}
"
,
]
conditions
=
" && "
.
join
(
conditions
)
if
kernel_selector_str
==
FILE_HEAD_COMMENT
:
kernel_selector_str
+=
f
"if (
{
conditions
}
)
\n
kernel = "
else
:
kernel_selector_str
+=
f
"else if (
{
conditions
}
)
\n
kernel = "
kernel_template2
=
(
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
file_content
=
FILE_HEAD
+
"
\n\n
"
file_content
+=
"
\n\n
"
.
join
(
all_template_str_list
)
+
"
\n\n
}
\n
"
if
a_type
==
"kFE4M3fn"
:
filename
=
f
"sm89_kernel_
{
a_type
[
1
:]
}
_
{
b_type
[
1
:]
}
_
{
c_type
[
1
:]
}
.cu"
else
:
filename
=
f
"sm80_kernel_
{
a_type
[
1
:]
}
_
{
b_type
[
1
:]
}
_
{
c_type
[
1
:]
}
.cu"
kernel_selector_str
+=
(
jinja2
.
Template
(
kernel_template2
).
render
(
a_type_id
=
f
"vllm::
{
a_type
}
.id()"
,
b_type_id
=
f
"vllm::
{
b_type
}
.id()"
,
c_type_id
=
f
"vllm::
{
c_type
}
.id()"
,
s_type_id
=
f
"vllm::
{
s_type
}
.id()"
,
**
config
,
)
+
"
\n
"
)
file_content
=
FILE_HEAD
+
"
\n\n
"
file_content
+=
"
\n\n
"
.
join
(
all_template_str_list
)
+
"
\n\n
}
\n
"
if
a_type
==
"kFE4M3fn"
:
filename
=
f
"sm89_kernel_
{
a_type
[
1
:]
}
_
{
b_type
[
1
:]
}
_
{
c_type
[
1
:]
}
.cu"
elif
result_dict_tmp
is
sm_75_result_dict
:
filename
=
f
"sm75_kernel_
{
a_type
[
1
:]
}
_
{
b_type
[
1
:]
}
_
{
c_type
[
1
:]
}
.cu"
else
:
filename
=
f
"sm80_kernel_
{
a_type
[
1
:]
}
_
{
b_type
[
1
:]
}
_
{
c_type
[
1
:]
}
.cu"
filename
=
filename
.
lower
()
filename
=
filename
.
lower
()
with
open
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
filename
),
"w"
)
as
f
:
f
.
write
(
file_content
)
with
open
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
filename
),
"w"
)
as
f
:
f
.
write
(
file_content
)
if
not
SUPPORT_FP8
and
kernel_selector_str
!=
FILE_HEAD_COMMENT
:
kernel_selector_str
+=
(
...
...
csrc/moe/marlin_moe_wna16/marlin_template.h
View file @
a810671a
...
...
@@ -26,6 +26,7 @@
#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "quantization/gptq_marlin/dequant.h"
#include "quantization/gptq_marlin/marlin_mma.h"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
...
...
@@ -35,7 +36,7 @@
namespace
MARLIN_NAMESPACE_NAME
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ <
80
0
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ <
75
0
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
const
vllm
::
ScalarTypeId
b_type_id
,
// weight MarlinScalarType id
...
...
@@ -84,146 +85,6 @@ __global__ void Marlin(
#else
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
template
<
vllm
::
ScalarTypeId
type_id
,
int
k_size
=
16
>
__device__
inline
void
mma
(
const
typename
MarlinScalarType
<
type_id
>::
FragA
&
a_frag
,
const
typename
MarlinScalarType
<
type_id
>::
FragB
&
frag_b
,
typename
MarlinScalarType
<
type_id
>::
FragC
&
frag_c
,
int
idx
=
0
)
{
const
uint32_t
*
a
=
reinterpret_cast
<
const
uint32_t
*>
(
&
a_frag
);
const
uint32_t
*
b
=
reinterpret_cast
<
const
uint32_t
*>
(
&
frag_b
);
using
scalar_t
=
typename
MarlinScalarType
<
type_id
>::
scalar_t
;
if
constexpr
(
k_size
==
16
)
{
if
constexpr
(
std
::
is_same
<
scalar_t
,
half
>::
value
)
{
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
)
{
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
__nv_fp8_e4m3
>::
value
)
{
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a
[
idx
*
2
]),
"r"
(
a
[
idx
*
2
+
1
]),
"r"
(
b
[
idx
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
int8_t
>::
value
)
{
int32_t
*
c
=
reinterpret_cast
<
int32_t
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};
\n
"
:
"=r"
(
c
[
0
]),
"=r"
(
c
[
1
]),
"=r"
(
c
[
2
]),
"=r"
(
c
[
3
])
:
"r"
(
a
[
idx
*
2
]),
"r"
(
a
[
idx
*
2
+
1
]),
"r"
(
b
[
idx
]),
"r"
(
c
[
0
]),
"r"
(
c
[
1
]),
"r"
(
c
[
2
]),
"r"
(
c
[
3
]));
}
}
else
if
(
k_size
==
32
)
{
if
constexpr
(
std
::
is_same
<
scalar_t
,
__nv_fp8_e4m3
>::
value
)
{
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
int8_t
>::
value
)
{
int32_t
*
c
=
reinterpret_cast
<
int32_t
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=r"
(
c
[
0
]),
"=r"
(
c
[
1
]),
"=r"
(
c
[
2
]),
"=r"
(
c
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]),
"r"
(
c
[
0
]),
"r"
(
c
[
1
]),
"r"
(
c
[
2
]),
"r"
(
c
[
3
]));
}
}
}
template
<
vllm
::
ScalarTypeId
type_id
,
int
k_size
=
16
>
__device__
inline
void
mma_trans
(
const
typename
MarlinScalarType
<
type_id
>::
FragA
&
a_frag
,
const
typename
MarlinScalarType
<
type_id
>::
FragB
&
frag_b
,
const
typename
MarlinScalarType
<
type_id
>::
FragB
&
frag_b2
,
typename
MarlinScalarType
<
type_id
>::
FragC
&
frag_c
)
{
const
uint32_t
*
a
=
reinterpret_cast
<
const
uint32_t
*>
(
&
a_frag
);
const
uint32_t
*
b
=
reinterpret_cast
<
const
uint32_t
*>
(
&
frag_b
);
const
uint32_t
*
b2
=
reinterpret_cast
<
const
uint32_t
*>
(
&
frag_b2
);
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
using
scalar_t
=
typename
MarlinScalarType
<
type_id
>::
scalar_t
;
if
constexpr
(
k_size
==
16
)
{
if
constexpr
(
std
::
is_same
<
scalar_t
,
half
>::
value
)
{
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
b
[
0
]),
"r"
(
b2
[
0
]),
"r"
(
b
[
1
]),
"r"
(
b2
[
1
]),
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
)
{
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
b
[
0
]),
"r"
(
b2
[
0
]),
"r"
(
b
[
1
]),
"r"
(
b2
[
1
]),
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
__nv_fp8_e4m3
>::
value
)
{
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
b
[
0
]),
"r"
(
b2
[
0
]),
"r"
(
a
[
0
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
int8_t
>::
value
)
{
int32_t
*
c
=
reinterpret_cast
<
int32_t
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};
\n
"
:
"=r"
(
c
[
0
]),
"=r"
(
c
[
1
]),
"=r"
(
c
[
2
]),
"=r"
(
c
[
3
])
:
"r"
(
b
[
0
]),
"r"
(
b2
[
0
]),
"r"
(
a
[
0
]),
"r"
(
c
[
0
]),
"r"
(
c
[
1
]),
"r"
(
c
[
2
]),
"r"
(
c
[
3
]));
}
}
else
{
if
constexpr
(
std
::
is_same
<
scalar_t
,
__nv_fp8_e4m3
>::
value
)
{
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1200
asm
volatile
(
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
b
[
0
]),
"r"
(
b2
[
0
]),
"r"
(
b
[
1
]),
"r"
(
b2
[
1
]),
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
#else
asm
volatile
(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
b
[
0
]),
"r"
(
b2
[
0
]),
"r"
(
b
[
1
]),
"r"
(
b2
[
1
]),
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
#endif
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
int8_t
>::
value
)
{
int32_t
*
c
=
reinterpret_cast
<
int32_t
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=r"
(
c
[
0
]),
"=r"
(
c
[
1
]),
"=r"
(
c
[
2
]),
"=r"
(
c
[
3
])
:
"r"
(
b
[
0
]),
"r"
(
b2
[
0
]),
"r"
(
b
[
1
]),
"r"
(
b2
[
1
]),
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
c
[
0
]),
"r"
(
c
[
1
]),
"r"
(
c
[
2
]),
"r"
(
c
[
3
]));
}
}
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
template
<
int
count
,
vllm
::
ScalarTypeId
type_id
>
...
...
@@ -439,9 +300,20 @@ __global__ void Marlin(
if
constexpr
(
a_type_id
==
vllm
::
kFE4M3fn
.
id
())
return
;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
// Turing TensorCore only supports fp16 and int8
if
constexpr
(
a_type_id
!=
vllm
::
kFloat16
.
id
()
&&
a_type_id
!=
vllm
::
kS8
.
id
())
return
;
#endif
int
num_tokens_past_padded
=
num_tokens_past_padded_ptr
[
0
];
constexpr
int
moe_block_size
=
m_block_size_8
?
8
:
(
16
*
thread_m_blocks
);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
constexpr
bool
use_fp16_accum
=
a_type_id
==
vllm
::
kFloat16
.
id
();
#else
constexpr
bool
use_fp16_accum
=
false
;
#endif
using
Adtype
=
MarlinScalarType
<
a_type_id
>
;
using
Cdtype
=
MarlinScalarType
<
c_type_id
>
;
...
...
@@ -618,7 +490,22 @@ __global__ void Marlin(
}
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
if
constexpr
(
moe_block_size
>=
16
)
local_count
+=
__shfl_down_sync
(
0xFFFFFFFF
,
local_count
,
16
);
if
constexpr
(
moe_block_size
>=
8
)
local_count
+=
__shfl_down_sync
(
0xFFFFFFFF
,
local_count
,
8
);
if
constexpr
(
moe_block_size
>=
4
)
local_count
+=
__shfl_down_sync
(
0xFFFFFFFF
,
local_count
,
4
);
if
constexpr
(
moe_block_size
>=
2
)
local_count
+=
__shfl_down_sync
(
0xFFFFFFFF
,
local_count
,
2
);
local_count
+=
__shfl_down_sync
(
0xFFFFFFFF
,
local_count
,
1
);
block_num_valid_tokens
=
local_count
;
#else
block_num_valid_tokens
=
__reduce_add_sync
(
0xffffffff
,
local_count
);
#endif
if
(
lane_id
==
0
)
reinterpret_cast
<
int
*>
(
sh_new
)[
0
]
=
block_num_valid_tokens
;
...
...
@@ -1018,10 +905,6 @@ __global__ void Marlin(
constexpr
int
sh_s_size
=
has_act_order
?
(
act_s_max_num_groups
*
s_sh_stride
)
:
(
stages
*
s_sh_stage
);
int4
*
sh_s
=
sh_zp
+
(
stages
*
zp_sh_stage
);
// shared memory reused by reduction should be smaller than
// shared memory used by weight.
static_assert
(
thread_m_blocks
*
16
*
thread_n_blocks
*
16
/
8
<=
stages
*
b_sh_stage
);
int4
*
sh_a
=
sh_s
+
sh_s_size
;
// Register storage for double buffer of shared memory reads.
...
...
@@ -1545,11 +1428,13 @@ __global__ void Marlin(
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
if
constexpr
(
m_block_size_8
)
{
mma_trans
<
a_type_id
>
(
frag_a
[
k2
][
i
],
frag_b0
,
frag_b1
,
frag_c
[
i
][
j
][
0
]);
mma_trans
<
a_type_id
,
use_fp16_accum
>
(
frag_a
[
k2
][
i
],
frag_b0
,
frag_b1
,
frag_c
[
i
][
j
][
0
]);
}
else
{
mma
<
a_type_id
>
(
frag_a
[
k2
][
i
],
frag_b0
,
frag_c
[
i
][
j
][
0
]);
mma
<
a_type_id
>
(
frag_a
[
k2
][
i
],
frag_b1
,
frag_c
[
i
][
j
][
1
]);
mma
<
a_type_id
,
use_fp16_accum
>
(
frag_a
[
k2
][
i
],
frag_b0
,
frag_c
[
i
][
j
][
0
]);
mma
<
a_type_id
,
use_fp16_accum
>
(
frag_a
[
k2
][
i
],
frag_b1
,
frag_c
[
i
][
j
][
1
]);
}
}
}
...
...
@@ -1583,10 +1468,12 @@ __global__ void Marlin(
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
mma
<
a_type_id
,
32
>
(
frag_a
[
k2
][
i
],
frag_b
[
0
],
(
group_blocks
==
-
1
?
frag_c
:
frag_c_tmp
)[
i
][
j
][
0
]);
mma
<
a_type_id
,
32
>
(
frag_a
[
k2
][
i
],
frag_b
[
1
],
(
group_blocks
==
-
1
?
frag_c
:
frag_c_tmp
)[
i
][
j
][
1
]);
mma
<
a_type_id
,
false
,
32
>
(
frag_a
[
k2
][
i
],
frag_b
[
0
],
(
group_blocks
==
-
1
?
frag_c
:
frag_c_tmp
)[
i
][
j
][
0
]);
mma
<
a_type_id
,
false
,
32
>
(
frag_a
[
k2
][
i
],
frag_b
[
1
],
(
group_blocks
==
-
1
?
frag_c
:
frag_c_tmp
)[
i
][
j
][
1
]);
}
if
constexpr
(
group_blocks
!=
-
1
)
{
...
...
@@ -2132,6 +2019,21 @@ __global__ void Marlin(
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation.
if
(
slice_iters
==
0
)
{
// convert fp16 accum to fp32 for reduction
if
constexpr
(
use_fp16_accum
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
(
thread_m_blocks
*
(
is_a_8bit
?
2
:
4
)
*
2
);
i
++
)
{
float
*
frag_c_part_float
=
reinterpret_cast
<
float
*>
(
frag_c
)
+
i
*
4
;
scalar_t
*
frag_c_part_half
=
reinterpret_cast
<
scalar_t
*>
(
frag_c_part_float
);
#pragma unroll
for
(
int
i
=
3
;
i
>=
0
;
i
--
)
{
frag_c_part_float
[
i
]
=
Cdtype
::
num2float
(
frag_c_part_half
[
i
]);
}
}
}
if
constexpr
(
is_a_8bit
)
{
float
frag_a_s
[
2
*
thread_m_blocks
];
...
...
Prev
1
2
3
4
5
6
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment