Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
98229db2
Unverified
Commit
98229db2
authored
Sep 13, 2025
by
Elvir Crnčević
Committed by
GitHub
Sep 13, 2025
Browse files
[Kernels][DP/EP] Optimize Silu Kernel for R1 (#24054)
Signed-off-by:
elvircrn
<
elvircrn@gmail.com
>
parent
dbeee384
Changes
6
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
1272 additions
and
131 deletions
+1272
-131
benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
+636
-38
csrc/ops.h
csrc/ops.h
+6
-0
csrc/quantization/activation_kernels.cu
csrc/quantization/activation_kernels.cu
+465
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+7
-0
tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
+71
-33
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+87
-60
No files found.
benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
View file @
98229db2
This diff is collapsed.
Click to expand it.
csrc/ops.h
View file @
98229db2
...
@@ -133,6 +133,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out,
...
@@ -133,6 +133,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
input_global_scale
);
torch
::
Tensor
&
input_global_scale
);
#endif
#endif
// Host entry point for the fused SiLU-and-multiply + grouped FP8 quantization
// used by the batched DeepGEMM MoE path. The definition lives in
// csrc/quantization/activation_kernels.cu.
//
// For every expert e, only the first counts[e] tokens are processed. The
// first H features of `input` are passed through SiLU, multiplied by the
// second H features, and written to `y_q` as fp8 together with one scale per
// (token, group) in `y_s`.
//
// The definition checks that `input` is bfloat16, `y_q` is float8_e4m3fn or
// float8_e4m3fnuz, `y_s` is float32, the last dimension of `input` is a
// multiple of 256, and `num_parallel_tokens` is a power of two in [1, 64].
// `use_ue8m0` selects power-of-two scales.
void
silu_mul_fp8_quant_deep_gemm_cuda
(
const
at
::
Tensor
&
input
,
// (E, T, 2*H)
const
at
::
Tensor
&
counts
,
// (E)
at
::
Tensor
&
y_q
,
// (E, T, H) [OUT]
at
::
Tensor
&
y_s
,
// (E, T, H//group_size) [OUT]
int64_t
group_size
,
bool
use_ue8m0
,
int64_t
num_parallel_tokens
);
void
mul_and_silu
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
mul_and_silu
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
...
...
csrc/quantization/activation_kernels.cu
View file @
98229db2
...
@@ -9,6 +9,26 @@
...
@@ -9,6 +9,26 @@
#include "quantization/fp8/common.cuh"
#include "quantization/fp8/common.cuh"
#include <c10/util/Float8_e4m3fn.h>
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_fp8.h>
typedef
__hip_bfloat162
__nv_bfloat162
;
typedef
__hip_bfloat16
__nv_bfloat16
;
typedef
__hip_bfloat16_raw
__nv_bfloat16_raw
;
typedef
__hip_fp8_e4m3
__nv_fp8_e4m3
;
typedef
__hip_fp8x4_e4m3
__nv_fp8x4_e4m3
;
#endif
#include "core/registration.h"
namespace
vllm
{
namespace
vllm
{
template
<
typename
T
>
template
<
typename
T
>
...
@@ -87,6 +107,337 @@ __global__ void act_and_mul_quant_kernel(
...
@@ -87,6 +107,337 @@ __global__ void act_and_mul_quant_kernel(
}
}
}
}
}
}
// SiLU activation, x * sigmoid(x), evaluated as x / (1 + exp(-x)).
// The quotient goes through the __fdividef intrinsic.
__device__ __forceinline__ float silu(float x) {
  const float denom = 1.f + expf(-x);
  return __fdividef(x, denom);
}
// Applies silu() independently to both components of a float2.
__device__ __forceinline__ float2 silu2(float2 x) {
  const float lo = silu(x.x);
  const float hi = silu(x.y);
  return make_float2(lo, hi);
}
#ifndef USE_ROCM
// Butterfly max-reduction over a warp: after the loop every participating
// lane holds the maximum of `v` across all lanes. All 32 mask bits are set,
// so every lane of the warp must call this together.
__device__ __forceinline__ float warp_max(float v) {
  constexpr unsigned kAllLanes = 0xffffffffu;
  for (int lane_xor = 1; lane_xor < WARP_SIZE; lane_xor *= 2) {
    const float peer = __shfl_xor_sync(kAllLanes, v, lane_xor);
    v = fmaxf(v, peer);
  }
  return v;
}
// bfloat16 overload of the warp-wide butterfly max-reduction. Every lane ends
// up with the maximum of `v` over the warp; all lanes must participate
// because the shuffle mask has all 32 bits set.
__device__ __forceinline__ __nv_bfloat16 warp_max(__nv_bfloat16 v) {
  constexpr unsigned kAllLanes = 0xffffffffu;
  for (int lane_xor = 1; lane_xor < WARP_SIZE; lane_xor *= 2) {
    const __nv_bfloat16 peer = __shfl_xor_sync(kAllLanes, v, lane_xor);
    v = __hmax(v, peer);
  }
  return v;
}
#endif
// Copies 16 bytes from global memory (`_glob_ptr`) to shared memory
// (`_smem_ptr`).
//
// When compiled with CUDA >= 11 for sm_80 or newer this issues an
// asynchronous `cp.async.cg.shared.global` instruction; completion must be
// awaited with cp_async_fence()/cp_async_wait(). On any other target it
// degrades to a plain assignment of a single element.
template <typename T, typename U>
__device__ __forceinline__ void cp_async4(T* _smem_ptr, const U* _glob_ptr) {
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
  constexpr int kCopyBytes = 16;
  const void* src = reinterpret_cast<const void*>(_glob_ptr);
  void* dst_generic = reinterpret_cast<void*>(_smem_ptr);
  // cp.async takes the destination as a 32-bit shared-space address.
  const uint32_t dst =
      static_cast<uint32_t>(__cvta_generic_to_shared(dst_generic));
  asm volatile(
      "{\n"
      " cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(dst),
      "l"(src), "n"(kCopyBytes));
#else
  _smem_ptr[0] = _glob_ptr[0];
#endif
}
// Closes the current batch of outstanding cp.async copies into a commit
// group. Compiles to nothing unless targeting CUDA >= 11 on sm_80+.
__device__ __forceinline__ void cp_async_fence() {
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
  asm volatile("cp.async.commit_group;\n" ::);
#endif
}
// Emits `cp.async.wait_group N`, with N passed as an immediate operand.
// Compiles to nothing unless targeting CUDA >= 11 on sm_80+.
template <int N>
__device__ __forceinline__ void cp_async_wait() {
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
  asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
#endif
}
// Specialization for N == 0: emits `cp.async.wait_all` instead of
// `cp.async.wait_group`. Compiles to nothing unless targeting CUDA >= 11 on
// sm_80+.
template <>
__device__ __forceinline__ void cp_async_wait<0>() {
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
  asm volatile("cp.async.wait_all;\n" ::);
#endif
}
// Clamps `v` into the closed interval [mmin, mmax].
//
// The return statement is deliberately NOT wrapped in an architecture guard:
// fminf/fmaxf are available on every target, and guarding it would leave this
// non-void function without a return value (undefined behaviour) whenever it
// is compiled for the host pass or for an architecture below sm_80.
__device__ __forceinline__ float clip(float v, float mmin, float mmax) {
  return fminf(mmax, fmaxf(v, mmin));
}
// bfloat16 overload: clamps `v` into [mmin, mmax] by first raising it to at
// least `mmin` and then capping it at `mmax`.
__device__ __forceinline__ __nv_bfloat16 clip(__nv_bfloat16 v,
                                              __nv_bfloat16 mmin,
                                              __nv_bfloat16 mmax) {
  const __nv_bfloat16 floored = __hmax(v, mmin);
  return __hmin(mmax, floored);
}
// Packed bfloat16x2 overload: clamps both halves of `v` component-wise into
// [mmin, mmax].
__device__ __forceinline__ __nv_bfloat162 clip(__nv_bfloat162 v,
                                               __nv_bfloat162 mmin,
                                               __nv_bfloat162 mmax) {
  const __nv_bfloat162 floored = __hmax2(v, mmin);
  return __hmin2(mmax, floored);
}
// We use the following values for fp8 min/max:
// __nv_fp8_e4m3 = (-448, +448)
// __nv_fp8_e4m3uz = (-240.0, +240.0)
// It is currently assumed that only these two fp8 formats are used; the
// static_asserts in the helpers below reject any other type.
// Largest finite value of the fp8 format `T`, returned as a bfloat16 that is
// built from its raw 16-bit pattern so the function can stay constexpr.
//   c10::Float8_e4m3fn   -> 17376 (0x43E0) == +448
//   c10::Float8_e4m3fnuz -> 17264 (0x4370) == +240
// Any other `T` is rejected at compile time.
template <class T>
constexpr __nv_bfloat16 get_fp8_max() {
  constexpr bool kIsE4m3fn = std::is_same_v<T, c10::Float8_e4m3fn>;
  constexpr bool kIsE4m3fnuz = std::is_same_v<T, c10::Float8_e4m3fnuz>;
  static_assert(kIsE4m3fn || kIsE4m3fnuz);
  if constexpr (kIsE4m3fn) {
    return __nv_bfloat16(__nv_bfloat16_raw{.x = 17376});
  } else {
    return __nv_bfloat16(__nv_bfloat16_raw{.x = 17264});
  }
}
// Most negative finite value of the fp8 format `T`, returned as a bfloat16
// that is built from its raw 16-bit pattern so the function can stay
// constexpr.
//   c10::Float8_e4m3fn   -> 50144 (0xC3E0) == -448
//   c10::Float8_e4m3fnuz -> 50032 (0xC370) == -240
// Any other `T` is rejected at compile time.
template <class T>
constexpr __nv_bfloat16 get_fp8_min() {
  constexpr bool kIsE4m3fn = std::is_same_v<T, c10::Float8_e4m3fn>;
  constexpr bool kIsE4m3fnuz = std::is_same_v<T, c10::Float8_e4m3fnuz>;
  static_assert(kIsE4m3fn || kIsE4m3fnuz);
  if constexpr (kIsE4m3fn) {
    return __nv_bfloat16(__nv_bfloat16_raw{.x = 50144});
  } else {
    return __nv_bfloat16(__nv_bfloat16_raw{.x = 50032});
  }
}
#ifndef USE_ROCM
// Fused kernel: y = silu(gate) * up, followed by per-group FP8 quantization.
//
// Grid layout
//   blockIdx.x enumerates (expert e, block-of-groups g) as e * G + g.
//   blockIdx.y selects which slice of the expert's valid tokens this block
//   processes; the valid token count is read from `counts`.
//
// Work split inside a block
//   Each warp owns one quantization group of GROUP_SIZE features, and each
//   lane produces 4 consecutive outputs (stored as one packed fp8x4 word).
//   The first half of the block's threads issue the global->shared copies
//   for the gate half of the input, the second half for the up half, each
//   thread moving one 128-bit word (8 bfloat16 values) per token.
//   Loads are software-pipelined over NUM_STAGES shared-memory stages.
//
// Template parameters
//   fp8_type             output element type (c10 e4m3fn or e4m3fnuz).
//   NUM_WARPS            warps per block == groups handled per block.
//   Idx_t                integer type used for strides/offsets.
//   NUM_PARALLEL_TOKENS  number of token slices (expected gridDim.y).
//   USE_UE8M0            round every scale up to a power of two.
//   GROUP_SIZE           features per quantization group.
//   NUM_STAGES           depth of the cp.async pipeline.
//
// NOTE(review): the 128-bit reinterpretation of `_input` and the packed
// 4-byte store into `_y_q` appear to assume unit stride along the feature
// dimension and 16-byte aligned rows -- confirm against the launcher.
template <typename fp8_type, int32_t NUM_WARPS, typename Idx_t,
          int NUM_PARALLEL_TOKENS, bool USE_UE8M0, int GROUP_SIZE = 128,
          int NUM_STAGES = 3>
__global__ void silu_mul_fp8_quant_deep_gemm_kernel(
    const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q,
    float* __restrict__ _y_s, const int32_t* __restrict__ counts,

    // sizes
    int H, int G,

    // strides (in elements)
    Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e,
    Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t,
    Idx_t stride_ys_g, Idx_t stride_counts_e) {
  static constexpr __nv_bfloat16 fp8_min = get_fp8_min<fp8_type>();
  static constexpr __nv_bfloat16 fp8_max = get_fp8_max<fp8_type>();
  // We assign EPS with its 16-bit unsigned counterpart to allow constexpr.
  // 11996 == 0x2EDC, i.e. roughly 1e-10 in bfloat16.
  static constexpr __nv_bfloat16 EPS = (__nv_bfloat16_raw{.x = 11996});

  // We pack 8 16-bit bfloat16 values into a 128-bit __int128_t.
  static constexpr int32_t BFLOAT16_PER_GROUP = 8;

  // We split the shared memory in half, corresponding to gate and up matrices:
  // [...gate_i, ...up_i] where 0 <= i < stages.
  static constexpr int32_t S_NUM_128 =
      2u * (GROUP_SIZE / BFLOAT16_PER_GROUP) * NUM_WARPS * NUM_STAGES;
  static constexpr auto THREAD_COUNT = NUM_WARPS * WARP_SIZE;
  static constexpr int HALF_THREAD_COUNT = THREAD_COUNT / 2;
  static constexpr int32_t S_NUM_64 = S_NUM_128 * 2;
  __shared__ __int128_t __align__(16) s_buff_128[S_NUM_128];

  const int32_t tid = threadIdx.x;
  const int32_t warp_id = tid / WARP_SIZE;
  const int32_t lane_id = tid % WARP_SIZE;

  auto s_buff_compute_32 = reinterpret_cast<__nv_bfloat162*>(s_buff_128);

  // block handles one (expert e, group g)
  int32_t pid = blockIdx.x;
  int32_t e = pid / G;
  int32_t g = pid % G;

  const int32_t n_tokens = counts[e * stride_counts_e];

  if (!n_tokens) {
    return;  // Exit ASAP.
  }

  const Idx_t stride_i_t_128 = stride_i_t / 8u;

  int32_t n_tokens_lower, n_tokens_upper;

  // Each block i iterates over tokens of a slice of n_tokens =
  // expert_counts[i], with the size of chunk being
  // (n_tokens / NUM_PARALLEL_TOKENS) + residual, instead of
  // updiv(n_tokens, NUM_PARALLEL_TOKENS) for better scheduling.
  if (n_tokens < NUM_PARALLEL_TOKENS && blockIdx.y < n_tokens) {
    // Specialize this, but can be likely fused.
    if (blockIdx.y >= NUM_PARALLEL_TOKENS) {
      return;
    }
    n_tokens_lower = blockIdx.y;
    n_tokens_upper = blockIdx.y + 1;
  } else {
    auto chunk_size = n_tokens / NUM_PARALLEL_TOKENS;
    auto residual = n_tokens - chunk_size * NUM_PARALLEL_TOKENS;
    // The first `residual` slices get one extra token each.
    auto calc_id = [&](int32_t id) {
      if (id < residual) {
        return min(n_tokens, id * (chunk_size + 1));
      } else {
        return min(n_tokens, id * chunk_size + residual);
      }
    };
    n_tokens_lower = calc_id(blockIdx.y);
    n_tokens_upper = calc_id(blockIdx.y + 1);
  }

  if (n_tokens_lower >= n_tokens_upper) {
    return;
  }

  // We do calculations here, using constexpr wherever possible.
  const Idx_t base_i = e * stride_i_e + NUM_WARPS * g * GROUP_SIZE * stride_i_h;
  const Idx_t base_ys = e * stride_ys_e + NUM_WARPS * g * stride_ys_g;
  const Idx_t base_yq =
      e * stride_yq_e + NUM_WARPS * g * GROUP_SIZE * stride_yq_h;
  Idx_t gate_off_128 = (base_i / static_cast<Idx_t>(8u));
  auto input_128_ptr = reinterpret_cast<const __int128_t*>(_input);
  auto gate_128_ptr = input_128_ptr + gate_off_128 + (tid % HALF_THREAD_COUNT) +
                      stride_i_t_128 * n_tokens_lower;
  // The up half starts H elements after the gate half within a token row.
  auto up_128_ptr = gate_128_ptr + (H * stride_i_h) / 8u;
  auto y_s_ptr =
      _y_s + base_ys + warp_id * stride_ys_g + n_tokens_lower * stride_ys_t;
  auto y_q_ptr = _y_q + base_yq + warp_id * GROUP_SIZE +
                 stride_yq_t * n_tokens_lower + 4 * lane_id;
  int32_t t_load = n_tokens_lower, load_stage_id = 0;
  auto s_buff_gate_load_128 = s_buff_128 + (tid % HALF_THREAD_COUNT);
  auto s_buff_up_load_128 = s_buff_gate_load_128 + S_NUM_128 / 2u;
  int32_t stage_offset{};

  static constexpr int32_t LOAD_STAGE_SIZE = (NUM_WARPS * WARP_SIZE / 2);
  static constexpr int32_t LOAD_STAGE_MOD =
      NUM_STAGES * (NUM_WARPS * WARP_SIZE / 2);

  // Two halves of all threads in a block conduct global loads for gate and up,
  // respectively.
  auto load_and_advance_y_pred = [&] {
    if (t_load < n_tokens_upper) {
      auto s_gate_stage_128_staged_ptr = s_buff_gate_load_128 + stage_offset;
      auto s_up_stage_128_staged_ptr = s_buff_up_load_128 + stage_offset;

      // It is very important that LOAD_STAGE_SIZE is constexpr to avoid
      // unnecessary ALU ops.
      stage_offset += LOAD_STAGE_SIZE;
      stage_offset %= LOAD_STAGE_MOD;

      if (tid < HALF_THREAD_COUNT) {
        cp_async4(s_gate_stage_128_staged_ptr, gate_128_ptr);
        gate_128_ptr += stride_i_t_128;
      } else {
        cp_async4(s_up_stage_128_staged_ptr, up_128_ptr);
        up_128_ptr += stride_i_t_128;
      }
      ++t_load;
      ++load_stage_id;
    }
    // We fence even if there is nothing to load to simplify pipelining.
    cp_async_fence();
  };

  // Prologue: fill all but one pipeline stage before the compute loop starts.
#pragma unroll
  for (int i = 0; i < NUM_STAGES - 1; i++) {
    load_and_advance_y_pred();
  }

  __int64_t* s_gate_ptr = reinterpret_cast<__int64_t*>(
                              s_buff_compute_32 + warp_id * (GROUP_SIZE / 2)) +
                          lane_id;
  __int64_t* s_up_ptr = s_gate_ptr + S_NUM_64 / 2;

  static constexpr int32_t STAGE_SIZE = (GROUP_SIZE * NUM_WARPS) / 4u;
  static constexpr int32_t STAGE_MOD = STAGE_SIZE * NUM_STAGES;

  int32_t compute_pipeline_offset_64 = 0;

  for (int32_t t = n_tokens_lower; t < n_tokens_upper; ++t) {
    // Starts at EPS so the group maximum (and hence the scale) can never be
    // zero.
    __nv_bfloat16 y_max_bf16 = EPS;
    __nv_bfloat162 results_bf162[2];

    cp_async_wait<NUM_STAGES - 2>();
    __syncthreads();

    // We double-buffer pipelined loads so that the next load will
    // concurrently run with compute without overwrites.
    load_and_advance_y_pred();

    auto s_gate_compute_64 = s_gate_ptr + compute_pipeline_offset_64;
    auto s_up_compute_64 = s_up_ptr + compute_pipeline_offset_64;

    // STAGE_SIZE must also be constexpr!
    compute_pipeline_offset_64 += STAGE_SIZE;
    compute_pipeline_offset_64 %= STAGE_MOD;

    // Each thread loads (gate/up) 2X 4X bfloat16 values into registers.
    __int64_t gate64 = *s_gate_compute_64;
    __nv_bfloat162* s_gate_compute_32 =
        reinterpret_cast<__nv_bfloat162*>(&gate64);

    __int64_t up64 = *s_up_compute_64;
    __nv_bfloat162* s_up_compute_32 = reinterpret_cast<__nv_bfloat162*>(&up64);

#pragma unroll
    for (int i = 0; i < 2; i++) {
      // For silu, we make sure that div is emitted.
      float2 gate = silu2(__bfloat1622float2(s_gate_compute_32[i]));
      results_bf162[i] = __float22bfloat162_rn(gate);
    }

#pragma unroll
    for (int i = 0; i < 2; i++) {
      results_bf162[i] = __hmul2(results_bf162[i], s_up_compute_32[i]);
    }

    auto _y_max2 =
        __hmax2(__habs2(results_bf162[0]), __habs2(results_bf162[1]));

    // Keep the EPS floor: overwriting y_max_bf16 here would let an all-zero
    // group produce a zero scale, an infinite reciprocal and NaN products.
    y_max_bf16 = __hmax(y_max_bf16, __hmax(_y_max2.x, _y_max2.y));

    // An entire group is assigned to a single warp, so a simple warp reduce
    // is used.
    __nv_bfloat16 y_s = warp_max(y_max_bf16) / fp8_max;

    if constexpr (USE_UE8M0) {
      // Round the scale up to the next power of two.
      y_s = hexp2(hceil(hlog2(y_s)));
    }

    auto inv_y = __float2bfloat16_rn(1.f) / y_s;

    auto y_s2 = make_bfloat162(inv_y, inv_y);

#pragma unroll
    for (int32_t i = 0; i < 2; ++i) {
      results_bf162[i] =
          clip(__hmul2(results_bf162[i], y_s2), __bfloat162bfloat162(fp8_min),
               __bfloat162bfloat162(fp8_max));
    }

    // Pack the four clamped values into one 32-bit fp8x4 word and store it.
    auto fp8x4 = __nv_fp8x4_e4m3(results_bf162[0], results_bf162[1]);
    *reinterpret_cast<__nv_fp8x4_e4m3*>(y_q_ptr) = fp8x4;
    y_q_ptr += stride_yq_t;

    // One scale per (token, group): only lane 0 of the warp writes it.
    if (lane_id == 0) {
      *y_s_ptr = y_s;
      y_s_ptr += stride_ys_t;
    }
  }
}
#endif
}
// namespace vllm
}
// namespace vllm
// Launch activation, gating, and quantize kernel.
// Launch activation, gating, and quantize kernel.
...
@@ -119,3 +470,117 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d]
...
@@ -119,3 +470,117 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d]
TORCH_CHECK
(
input
.
size
(
-
1
)
%
2
==
0
);
TORCH_CHECK
(
input
.
size
(
-
1
)
%
2
==
0
);
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
silu_kernel
);
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
silu_kernel
);
}
}
// Launches silu_mul_fp8_quant_deep_gemm_kernel.
//
// Arguments
//   input                bfloat16, (E, T, 2*H); gate in [..., :H], up in
//                        [..., H:]. Must have unit stride along the last dim.
//   counts               (E), number of valid tokens per expert (read as int).
//   y_q                  [OUT] fp8 e4m3fn / e4m3fnuz, (E, T, H), unit stride
//                        along the last dim.
//   y_s                  [OUT] float32, (E, T, H // group_size).
//   group_size           must be 128 (the kernel's compiled-in GROUP_SIZE).
//   use_ue8m0            round scales up to powers of two.
//   num_parallel_tokens  power of two in [1, 64]; becomes gridDim.y.
//
// Raises (via TORCH_CHECK) when any of the above constraints is violated, and
// on ROCm builds, where no implementation exists.
void silu_mul_fp8_quant_deep_gemm_cuda(
    const at::Tensor& input,   // (E, T, 2*H)
    const at::Tensor& counts,  // (E)
    at::Tensor& y_q,           // (E, T, H) [OUT]
    at::Tensor& y_s,           // (E, T, H//group_size) [OUT]
    int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens) {
#ifndef USE_ROCM
  // This kernel relies heavily on cp.async and fp8 support.
  // This kernel currently only supports H % 128 == 0 and assumes a
  // fixed GROUP_SIZE of 128.
  static constexpr int GROUP_SIZE = 128;

  TORCH_CHECK(input.dim() == 3, "input must have shape (E, T, 2*H)");
  TORCH_CHECK(input.dtype() == torch::kBFloat16);
  TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
              y_q.dtype() == torch::kFloat8_e4m3fnuz);
  TORCH_CHECK(y_s.dtype() == torch::kFloat32);
  TORCH_CHECK(input.size(-1) % 256 == 0);
  // G below is derived from group_size while the kernel indexes with its
  // compiled-in GROUP_SIZE; any other value would address the wrong data.
  TORCH_CHECK(group_size == GROUP_SIZE,
              "silu_mul_fp8_quant_deep_gemm_cuda only supports group_size == ",
              GROUP_SIZE, ", got ", group_size);

  // Check that num_parallel_tokens is of power of 2 and between 1 and 64.
  TORCH_CHECK(1 <= num_parallel_tokens && num_parallel_tokens <= 64);
  TORCH_CHECK(!(num_parallel_tokens & (num_parallel_tokens - 1)));

  // Nothing to compute; also avoids launching a zero-sized grid.
  if (input.numel() == 0) {
    return;
  }

  // The kernel reads the input through 128-bit vector loads and writes the
  // output as packed 4-byte words, so the feature dimension must be dense.
  TORCH_CHECK(input.stride(-1) == 1,
              "input must have unit stride along the last dimension");
  TORCH_CHECK(y_q.stride(-1) == 1,
              "y_q must have unit stride along the last dimension");

  using Idx_t = int64_t;

  Idx_t E = input.size(0);
  Idx_t H = input.size(2) / 2;
  Idx_t stride_i_e = input.stride(0);
  Idx_t stride_i_t = input.stride(1);
  Idx_t stride_i_h = input.stride(2);
  Idx_t stride_yq_e = y_q.stride(0);
  Idx_t stride_yq_t = y_q.stride(1);
  Idx_t stride_yq_h = y_q.stride(2);
  Idx_t stride_ys_e = y_s.stride(0);
  Idx_t stride_ys_t = y_s.stride(1);
  Idx_t stride_ys_g = y_s.stride(2);

  Idx_t stride_counts_e = counts.stride(0);

  // Instantiates and launches the kernel for the runtime value of use_ue8m0.
  #define KERNEL_FN                                                          \
    if (use_ue8m0) {                                                         \
      vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t,     \
                                                NUM_PARALLEL_TOKENS, true>   \
          <<<grid, block, 0, stream>>>(                                      \
              reinterpret_cast<__nv_bfloat16*>(input.data_ptr()),            \
              reinterpret_cast<fp8_t*>(y_q.data_ptr()),                      \
              y_s.data_ptr<float>(),                                         \
              reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G,      \
              stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t,  \
              stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g,            \
              stride_counts_e);                                              \
    } else {                                                                 \
      vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t,     \
                                                NUM_PARALLEL_TOKENS, false>  \
          <<<grid, block, 0, stream>>>(                                      \
              reinterpret_cast<__nv_bfloat16*>(input.data_ptr()),            \
              reinterpret_cast<fp8_t*>(y_q.data_ptr()),                      \
              y_s.data_ptr<float>(),                                         \
              reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G,      \
              stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t,  \
              stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g,            \
              stride_counts_e);                                              \
    }

  // Uses 4 warps per block when H is a multiple of 4 groups, otherwise 1.
  #define KERNEL_CALL_H                                       \
    if (H % (4 * GROUP_SIZE) == 0) {                          \
      static constexpr int NUM_WARPS = 4;                     \
      populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
      KERNEL_FN                                               \
    } else {                                                  \
      static constexpr int NUM_WARPS = 1;                     \
      populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
      KERNEL_FN                                               \
    }

  // Maps the runtime num_parallel_tokens onto a compile-time constant.
  #define KERNEL_CALL_TOP_LEVEL                      \
    if (num_parallel_tokens == 1) {                  \
      static constexpr int NUM_PARALLEL_TOKENS = 1;  \
      KERNEL_CALL_H                                  \
    } else if (num_parallel_tokens == 2) {           \
      static constexpr int NUM_PARALLEL_TOKENS = 2;  \
      KERNEL_CALL_H                                  \
    } else if (num_parallel_tokens == 4) {           \
      static constexpr int NUM_PARALLEL_TOKENS = 4;  \
      KERNEL_CALL_H                                  \
    } else if (num_parallel_tokens == 8) {           \
      static constexpr int NUM_PARALLEL_TOKENS = 8;  \
      KERNEL_CALL_H                                  \
    } else if (num_parallel_tokens == 16) {          \
      static constexpr int NUM_PARALLEL_TOKENS = 16; \
      KERNEL_CALL_H                                  \
    } else if (num_parallel_tokens == 32) {          \
      static constexpr int NUM_PARALLEL_TOKENS = 32; \
      KERNEL_CALL_H                                  \
    } else if (num_parallel_tokens == 64) {          \
      static constexpr int NUM_PARALLEL_TOKENS = 64; \
      KERNEL_CALL_H                                  \
    }

  Idx_t G;
  dim3 block, grid;
  auto populate_launch_params = [&](int num_warps, int _num_parallel_tokens) {
    G = H / Idx_t(group_size * num_warps);
    grid = dim3(E * G, _num_parallel_tokens);
    block = dim3(num_warps * WARP_SIZE);
  };

  // Install the device guard first so that the "current" stream queried
  // below belongs to the device that owns `input`.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FP8_TYPES(y_q.scalar_type(),
                          "silu_mul_fp8_quant_deep_gemm_kernel",
                          [&] { KERNEL_CALL_TOP_LEVEL });

  // Keep the launch helpers local to this function.
  #undef KERNEL_CALL_TOP_LEVEL
  #undef KERNEL_CALL_H
  #undef KERNEL_FN
#else
  // There is no ROCm implementation; fail loudly instead of returning with
  // y_q / y_s left uninitialized.
  TORCH_CHECK(false,
              "silu_mul_fp8_quant_deep_gemm_cuda is not supported on ROCm");
#endif
}
csrc/torch_bindings.cpp
View file @
98229db2
...
@@ -32,6 +32,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -32,6 +32,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#define stride_tag
#define stride_tag
#endif
#endif
ops
.
def
(
"silu_mul_fp8_quant_deep_gemm_cuda(Tensor input, Tensor counts, Tensor! "
"y_q, Tensor! y_s, int group_size, "
"bool use_ue8m0, int num_parallel_tokens) -> ()"
);
ops
.
impl
(
"silu_mul_fp8_quant_deep_gemm_cuda"
,
torch
::
kCUDA
,
&
silu_mul_fp8_quant_deep_gemm_cuda
);
ops
.
def
(
"weak_ref_tensor(Tensor input) -> Tensor"
);
ops
.
def
(
"weak_ref_tensor(Tensor input) -> Tensor"
);
ops
.
impl
(
"weak_ref_tensor"
,
torch
::
kCUDA
,
&
weak_ref_tensor
);
ops
.
impl
(
"weak_ref_tensor"
,
torch
::
kCUDA
,
&
weak_ref_tensor
);
...
...
tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
View file @
98229db2
...
@@ -5,28 +5,52 @@ import pytest
...
@@ -5,28 +5,52 @@ import pytest
import
torch
import
torch
from
vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe
import
(
from
vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe
import
(
silu_mul_fp8_quant_deep_gemm
)
silu_mul_fp8_quant_deep_gemm
_cuda
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
fp8_dtype
=
torch
.
float8_e4m3fn
# (E, T, H, group_size, seed)
CASES
=
[
CASES
=
[
(
1
,
1
,
128
,
64
,
0
),
(
1
,
1
,
128
,
fp8_dtype
),
(
1
,
4
,
128
,
128
,
0
),
(
1
,
4
,
128
,
fp8_dtype
),
(
2
,
4
,
256
,
128
,
0
),
(
2
,
4
,
256
,
fp8_dtype
),
(
32
,
64
,
256
,
128
,
0
),
(
32
,
64
,
256
,
fp8_dtype
),
(
17
,
31
,
768
,
128
,
0
),
(
17
,
31
,
768
,
fp8_dtype
),
(
1
,
1
,
128
*
1
,
fp8_dtype
),
(
1
,
1
,
128
*
2
,
fp8_dtype
),
(
1
,
1
,
128
*
3
,
fp8_dtype
),
(
1
,
1
,
128
*
4
,
fp8_dtype
),
(
8
,
16
,
128
*
1
,
fp8_dtype
),
(
8
,
16
,
128
*
2
,
fp8_dtype
),
(
8
,
16
,
128
*
3
,
fp8_dtype
),
(
8
,
16
,
128
*
4
,
fp8_dtype
),
(
8
,
64
,
7168
,
fp8_dtype
),
(
8
,
128
,
7168
,
fp8_dtype
),
(
8
,
256
,
7168
,
fp8_dtype
),
(
8
,
512
,
7168
,
fp8_dtype
),
(
8
,
1024
,
7168
,
fp8_dtype
),
(
256
,
8
,
7168
,
fp8_dtype
),
(
256
,
16
,
7168
,
fp8_dtype
),
(
256
,
32
,
7168
,
fp8_dtype
),
(
256
,
64
,
7168
,
fp8_dtype
),
# Only add a few fnuz tests to help with long CI times.
(
8
,
512
,
7168
,
torch
.
float8_e4m3fnuz
),
(
8
,
1024
,
7168
,
torch
.
float8_e4m3fnuz
),
]
]
@
pytest
.
mark
.
parametrize
(
"E,T,H,
group_size,seed
"
,
CASES
)
@
pytest
.
mark
.
parametrize
(
"E,T,H,
fp8_type
"
,
CASES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_silu_mul_fp8_quant_deep_gemm
(
E
,
T
,
H
,
group_size
,
seed
):
def
test_silu_mul_fp8_quant_deep_gemm
(
E
,
T
,
H
,
fp8_type
):
current_platform
.
seed_everything
(
seed
)
group_size
=
128
current_platform
.
seed_everything
(
42
)
# Input tensor of shape (E, T, 2*H)
# Input tensor of shape (E, T, 2*H)
y
=
torch
.
randn
((
E
,
T
,
2
*
H
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
y
=
torch
.
randn
((
E
,
T
,
2
*
H
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
tokens_per_expert
=
torch
.
randint
(
tokens_per_expert
=
torch
.
randint
(
low
=
0
,
low
=
T
//
2
,
high
=
T
,
high
=
T
,
size
=
(
E
,
),
size
=
(
E
,
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
...
@@ -34,45 +58,59 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
...
@@ -34,45 +58,59 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
)
)
# Run the Triton kernel
# Run the Triton kernel
y_q
,
y_s
=
silu_mul_fp8_quant_deep_gemm
(
y
,
y_q
,
y_s
=
silu_mul_fp8_quant_deep_gemm
_cuda
(
y
,
tokens_per_expert
,
tokens_per_expert
,
group_size
=
group_size
,
group_size
=
group_size
)
eps
=
1e-10
)
# Reference implementation
torch
.
cuda
.
synchronize
()
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
fp8_info
=
torch
.
finfo
(
fp8_dtype
)
fp8_max
=
fp8_info
.
max
fp8_max
=
fp8_info
.
max
fp8_min
=
fp8_info
.
min
fp8_min
=
fp8_info
.
min
eps
=
1e-10
eps
=
1e-10
# Compute silu activation and elementwise multiplication
y1
=
y
[...,
:
H
].
float
()
y1
=
y
[...,
:
H
]
y2
=
y
[...,
H
:]
y2
=
y
[...,
H
:]
silu_x
=
y1
*
torch
.
sigmoid
(
y1
)
silu_x
=
y1
*
torch
.
sigmoid
(
y1
)
merged
=
silu_x
*
y2
merged
=
silu_x
*
y2
# Compute reference scales and quantized output, skipping padded tokens
for
e
in
range
(
E
):
for
e
in
range
(
E
):
nt
=
tokens_per_expert
[
e
].
item
()
nt
=
tokens_per_expert
[
e
].
item
()
ref_s
=
torch
.
empty
((
T
,
H
//
group_size
),
ref_s
=
torch
.
empty
((
T
,
cdiv
(
H
,
group_size
)
)
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
device
=
"cuda"
)
ref_q
=
torch
.
empty
((
T
,
H
),
dtype
=
torch
.
float8_e4m3fn
,
device
=
"cuda"
)
ref_q
=
torch
.
empty
((
T
,
H
),
dtype
=
fp8_dtype
,
device
=
"cuda"
)
for
t
in
range
(
nt
):
for
t
in
range
(
nt
):
data
=
merged
[
e
,
t
]
data
=
merged
[
e
,
t
].
float
()
data_grp
=
data
.
view
(
H
//
group_size
,
group_size
)
ref_q_row
=
torch
.
empty_like
(
data
)
# process full groups
n_full_groups
=
H
//
group_size
if
n_full_groups
>
0
:
data_grp
=
data
[:
n_full_groups
*
group_size
].
view
(
n_full_groups
,
group_size
)
amax
=
data_grp
.
abs
().
amax
(
dim
=
1
).
clamp
(
min
=
eps
)
amax
=
data_grp
.
abs
().
amax
(
dim
=
1
).
clamp
(
min
=
eps
)
scale
=
amax
/
fp8_max
scale
=
amax
/
fp8_max
scaled
=
data
[:
n_full_groups
*
group_size
]
/
scale
.
repeat_interleave
(
group_size
)
ref_q_row
[:
n_full_groups
*
group_size
]
=
scaled
.
clamp
(
fp8_min
,
fp8_max
).
to
(
fp8_dtype
)
ref_s
[
t
,
:
n_full_groups
]
=
scale
scaled
=
data
/
scale
.
repeat_interleave
(
group_size
)
# process remainder group
clamped
=
scaled
.
clamp
(
fp8_min
,
fp8_max
)
rem
=
H
%
group_size
q
=
clamped
.
to
(
torch
.
float8_e4m3fn
)
if
rem
>
0
:
data_rem
=
data
[
-
rem
:]
amax
=
data_rem
.
abs
().
amax
().
clamp
(
min
=
eps
)
scale
=
amax
/
fp8_max
scaled
=
data_rem
/
scale
ref_q_row
[
-
rem
:]
=
scaled
.
clamp
(
fp8_min
,
fp8_max
).
to
(
fp8_dtype
)
ref_s
[
t
,
-
1
]
=
scale
ref_s
[
t
]
=
scale
ref_q
[
t
]
=
ref_q_row
ref_q
[
t
]
=
q
y_se
=
y_s
[
e
]
y_se
=
y_s
[
e
]
.
float
()
y_qe
=
y_q
[
e
]
y_qe
=
y_q
[
e
]
.
float
()
torch
.
testing
.
assert_close
(
y_se
[:
nt
],
ref_s
[:
nt
],
atol
=
1e-4
,
rtol
=
1e-2
)
torch
.
testing
.
assert_close
(
y_se
[:
nt
],
ref_s
[:
nt
],
atol
=
1e-4
,
rtol
=
1e-2
)
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
...
...
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
98229db2
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
math
import
log2
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
...
@@ -10,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
...
@@ -10,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceDelegate
)
TopKWeightAndReduceDelegate
)
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.deep_gemm
import
(
fp8_m_grouped_gemm_nt_masked
,
from
vllm.utils.deep_gemm
import
(
fp8_m_grouped_gemm_nt_masked
,
is_deep_gemm_e8m0_used
)
is_deep_gemm_e8m0_used
)
...
@@ -24,35 +26,28 @@ def _silu_mul_fp8_quant_deep_gemm(
...
@@ -24,35 +26,28 @@ def _silu_mul_fp8_quant_deep_gemm(
y_q_ptr
,
# fp8 quantized activations (E, T, H)
y_q_ptr
,
# fp8 quantized activations (E, T, H)
y_s_ptr
,
# 16-bit scales (E, T, G)
y_s_ptr
,
# 16-bit scales (E, T, G)
counts_ptr
,
# int32 num tokens per expert (E)
counts_ptr
,
# int32 num tokens per expert (E)
# Sizes ---------------------------------------------------------------
# Sizes ---------------------------------------------------------------
H
:
tl
.
constexpr
,
# hidden dimension (per output)
H
:
tl
.
constexpr
,
# hidden dimension (per output)
GROUP_SIZE
:
tl
.
constexpr
,
# elements per group (usually 128)
GROUP_SIZE
:
tl
.
constexpr
,
# elements per group (usually 128)
# Strides for input (elements) ---------------------------------------
# Strides for input (elements) ---------------------------------------
stride_i_e
,
stride_i_e
,
stride_i_t
,
stride_i_t
,
stride_i_h
,
stride_i_h
,
# Strides for y_q (elements) -----------------------------------------
# Strides for y_q (elements) -----------------------------------------
stride_yq_e
,
stride_yq_e
,
stride_yq_t
,
stride_yq_t
,
stride_yq_h
,
stride_yq_h
,
# Strides for y_s (elements) -----------------------------------------
# Strides for y_s (elements) -----------------------------------------
stride_ys_e
,
stride_ys_e
,
stride_ys_t
,
stride_ys_t
,
stride_ys_g
,
stride_ys_g
,
# Stride for counts (elements)
# Stride for counts (elements)
stride_counts_e
,
stride_counts_e
,
# Numeric params ------------------------------------------------------
# Numeric params ------------------------------------------------------
eps
:
tl
.
constexpr
,
eps
:
tl
.
constexpr
,
fp8_min
:
tl
.
constexpr
,
fp8_min
:
tl
.
constexpr
,
fp8_max
:
tl
.
constexpr
,
fp8_max
:
tl
.
constexpr
,
use_ue8m0
:
tl
.
constexpr
,
use_ue8m0
:
tl
.
constexpr
,
# Meta ---------------------------------------------------------------
# Meta ---------------------------------------------------------------
BLOCK
:
tl
.
constexpr
,
BLOCK
:
tl
.
constexpr
,
NUM_STAGES
:
tl
.
constexpr
,
NUM_STAGES
:
tl
.
constexpr
,
...
@@ -101,17 +96,15 @@ def _silu_mul_fp8_quant_deep_gemm(
...
@@ -101,17 +96,15 @@ def _silu_mul_fp8_quant_deep_gemm(
tl
.
store
(
y_s_ptr
+
base_ys_offset
+
t
*
stride_ys_t
,
y_s
)
tl
.
store
(
y_s_ptr
+
base_ys_offset
+
t
*
stride_ys_t
,
y_s
)
def
silu_mul_fp8_quant_deep_gemm
(
def
silu_mul_fp8_quant_deep_gemm
_cuda
(
y
:
torch
.
Tensor
,
# (E, T, 2*H)
y
:
torch
.
Tensor
,
# (E, T, 2*H)
tokens_per_expert
:
torch
.
Tensor
,
# (E,) number of valid tokens per expert
tokens_per_expert
:
torch
.
Tensor
,
# (E,) number of valid tokens per expert
num_parallel_tokens
=
16
,
group_size
:
int
=
128
,
group_size
:
int
=
128
,
eps
:
float
=
1e-10
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
y has shape (E, T, 2*H). The first half of the last dimension is
y has shape (E, T, 2*H). The first half of the last dimension is
silu-activated, multiplied by the second half, then quantized into FP8.
silu-activated, multiplied by the second half, then quantized into FP8.
Returns `(y_q, y_s)` where
Returns `(y_q, y_s)` where
* `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
* `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
* `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
* `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
...
@@ -120,22 +113,17 @@ def silu_mul_fp8_quant_deep_gemm(
...
@@ -120,22 +113,17 @@ def silu_mul_fp8_quant_deep_gemm(
E
,
T
,
H2
=
y
.
shape
E
,
T
,
H2
=
y
.
shape
assert
H2
%
2
==
0
,
"last dim of y must be even (2*H)"
assert
H2
%
2
==
0
,
"last dim of y must be even (2*H)"
H
=
H2
//
2
H
=
H2
//
2
G
=
H
//
group_size
G
=
(
H
+
group_size
-
1
)
//
group_size
assert
H
%
group_size
==
0
,
"H must be divisible by group_size"
assert
H
%
8
==
0
,
"H must be divisible by 8"
assert
tokens_per_expert
.
ndim
==
1
and
tokens_per_expert
.
shape
[
0
]
==
E
,
\
assert
group_size
==
128
,
"H must be divisible by 8"
"tokens_per_expert must be shape (E,)"
assert
tokens_per_expert
.
ndim
==
1
and
tokens_per_expert
.
shape
[
0
]
==
E
tokens_per_expert
=
tokens_per_expert
.
to
(
device
=
y
.
device
,
tokens_per_expert
=
tokens_per_expert
.
to
(
device
=
y
.
device
,
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
# allocate outputs
fp8_dtype
=
torch
.
float8_e4m3fn
fp8_dtype
=
torch
.
float8_e4m3fn
y_q
=
torch
.
empty
((
E
,
T
,
H
),
dtype
=
fp8_dtype
,
device
=
y
.
device
)
y_q
=
torch
.
empty
((
E
,
T
,
H
),
dtype
=
fp8_dtype
,
device
=
y
.
device
)
# strides (elements)
stride_i_e
,
stride_i_t
,
stride_i_h
=
y
.
stride
()
stride_yq_e
,
stride_yq_t
,
stride_yq_h
=
y_q
.
stride
()
# desired scale strides (elements): (T*G, 1, T)
stride_ys_e
=
T
*
G
stride_ys_e
=
T
*
G
stride_ys_t
=
1
stride_ys_t
=
1
stride_ys_g
=
T
stride_ys_g
=
T
...
@@ -144,16 +132,56 @@ def silu_mul_fp8_quant_deep_gemm(
...
@@ -144,16 +132,56 @@ def silu_mul_fp8_quant_deep_gemm(
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
y
.
device
)
device
=
y
.
device
)
use_ue8m0
=
is_deep_gemm_e8m0_used
()
if
E
<=
16
:
max_empirical_parallelism
=
64
elif
E
<=
32
:
max_empirical_parallelism
=
16
else
:
max_empirical_parallelism
=
4
# We never want to launch more than Tx number of threads
# This computes the clip.
num_parallel_tokens
=
max
(
1
,
min
(
max_empirical_parallelism
,
2
**
int
(
log2
(
min
(
num_parallel_tokens
,
T
)))))
cuda_arch
=
current_platform
.
get_device_capability
(
device_id
=
y
.
device
.
index
).
to_int
()
if
cuda_arch
>=
80
:
torch
.
ops
.
_C
.
silu_mul_fp8_quant_deep_gemm_cuda
(
y
,
tokens_per_expert
,
y_q
,
y_s
,
group_size
,
use_ue8m0
,
num_parallel_tokens
)
else
:
# Default to triton if not on cuda or if arch is too old
y_q
=
torch
.
empty
((
E
,
T
,
H
),
dtype
=
fp8_dtype
,
device
=
y
.
device
)
stride_cnt_e
=
tokens_per_expert
.
stride
()[
0
]
stride_cnt_e
=
tokens_per_expert
.
stride
()[
0
]
# Static grid over experts and H-groups.
# Static grid over experts and H-groups.
# A loop inside the kernel handles the token dim
# A loop inside the kernel handles the token dim
grid
=
(
E
*
G
,
)
grid
=
(
E
*
G
,
)
# strides (elements)
stride_i_e
,
stride_i_t
,
stride_i_h
=
y
.
stride
()
stride_yq_e
,
stride_yq_t
,
stride_yq_h
=
y_q
.
stride
()
# desired scale strides (elements): (T*G, 1, T)
stride_ys_e
=
T
*
G
stride_ys_t
=
1
stride_ys_g
=
T
y_s
=
torch
.
empty_strided
(
(
E
,
T
,
G
),
(
stride_ys_e
,
stride_ys_t
,
stride_ys_g
),
dtype
=
torch
.
float32
,
device
=
y
.
device
,
)
f_info
=
torch
.
finfo
(
fp8_dtype
)
f_info
=
torch
.
finfo
(
fp8_dtype
)
fp8_max
=
f_info
.
max
fp8_max
=
f_info
.
max
fp8_min
=
f_info
.
min
fp8_min
=
f_info
.
min
eps
:
float
=
1e-10
_silu_mul_fp8_quant_deep_gemm
[
grid
](
_silu_mul_fp8_quant_deep_gemm
[
grid
](
y
,
y
,
y_q
,
y_q
,
...
@@ -184,7 +212,6 @@ def silu_mul_fp8_quant_deep_gemm(
...
@@ -184,7 +212,6 @@ def silu_mul_fp8_quant_deep_gemm(
class
BatchedDeepGemmExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
class
BatchedDeepGemmExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
# The Deep Gemm kernels only support block size of 128
# The Deep Gemm kernels only support block size of 128
DEEPGEMM_BLOCK_SHAPE
:
list
[
int
]
=
[
128
,
128
]
DEEPGEMM_BLOCK_SHAPE
:
list
[
int
]
=
[
128
,
128
]
...
@@ -297,8 +324,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -297,8 +324,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
fp8_m_grouped_gemm_nt_masked
((
a1q
,
a1q_scale
),
(
w1
,
w1_scale
),
fp8_m_grouped_gemm_nt_masked
((
a1q
,
a1q_scale
),
(
w1
,
w1_scale
),
workspace1
,
expert_num_tokens
,
expected_m
)
workspace1
,
expert_num_tokens
,
expected_m
)
a2q
,
a2q_scale
=
silu_mul_fp8_quant_deep_gemm
(
workspace1
,
a2q
,
a2q_scale
=
silu_mul_fp8_quant_deep_gemm
_cuda
(
expert_num_tokens
)
workspace1
,
expert_num_tokens
)
fp8_m_grouped_gemm_nt_masked
((
a2q
,
a2q_scale
),
(
w2
,
w2_scale
),
output
,
fp8_m_grouped_gemm_nt_masked
((
a2q
,
a2q_scale
),
(
w2
,
w2_scale
),
output
,
expert_num_tokens
,
expected_m
)
expert_num_tokens
,
expected_m
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment