Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8d75f22e
Commit
8d75f22e
authored
Dec 13, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori
parents
ce888aa4
7d80c73d
Changes
706
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1997 additions
and
526 deletions
+1997
-526
benchmarks/kernels/benchmark_moe_align_block_size.py
benchmarks/kernels/benchmark_moe_align_block_size.py
+17
-4
csrc/cpu/cpu_attn_impl.hpp
csrc/cpu/cpu_attn_impl.hpp
+0
-13
csrc/cpu/cpu_attn_macros.h
csrc/cpu/cpu_attn_macros.h
+50
-0
csrc/dispatch_utils.h
csrc/dispatch_utils.h
+18
-0
csrc/moe/grouped_topk_kernels.cu
csrc/moe/grouped_topk_kernels.cu
+128
-47
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+434
-91
csrc/moe/moe_lora_align_sum_kernels.cu
csrc/moe/moe_lora_align_sum_kernels.cu
+0
-174
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+3
-2
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+4
-2
csrc/ops.h
csrc/ops.h
+25
-6
csrc/quantization/cutlass_w4a8/get_group_starts.cuh
csrc/quantization/cutlass_w4a8/get_group_starts.cuh
+104
-0
csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu
csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu
+483
-0
csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
+3
-67
csrc/quantization/cutlass_w4a8/w4a8_utils.cu
csrc/quantization/cutlass_w4a8/w4a8_utils.cu
+90
-0
csrc/quantization/cutlass_w4a8/w4a8_utils.cuh
csrc/quantization/cutlass_w4a8/w4a8_utils.cuh
+11
-0
csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
.../fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
+124
-20
csrc/quantization/fused_kernels/layernorm_utils.cuh
csrc/quantization/fused_kernels/layernorm_utils.cuh
+308
-94
csrc/quantization/w8a8/cutlass/moe/moe_data.cu
csrc/quantization/w8a8/cutlass/moe/moe_data.cu
+5
-3
csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
+5
-3
csrc/quantization/w8a8/fp8/per_token_group_quant.cu
csrc/quantization/w8a8/fp8/per_token_group_quant.cu
+185
-0
No files found.
benchmarks/kernels/benchmark_moe_align_block_size.py
View file @
8d75f22e
...
...
@@ -24,12 +24,15 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
num_tokens_range
=
[
1
,
16
,
256
,
4096
]
num_experts_range
=
[
16
,
64
,
224
,
256
,
280
,
512
]
topk_range
=
[
1
,
2
,
8
]
configs
=
list
(
itertools
.
product
(
num_tokens_range
,
num_experts_range
,
topk_range
))
ep_size_range
=
[
1
,
8
]
configs
=
list
(
itertools
.
product
(
num_tokens_range
,
num_experts_range
,
topk_range
,
ep_size_range
)
)
@
triton
.
testing
.
perf_report
(
triton
.
testing
.
Benchmark
(
x_names
=
[
"num_tokens"
,
"num_experts"
,
"topk"
],
x_names
=
[
"num_tokens"
,
"num_experts"
,
"topk"
,
"ep_size"
],
x_vals
=
configs
,
line_arg
=
"provider"
,
line_vals
=
[
"vllm"
],
...
...
@@ -38,16 +41,26 @@ configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range
args
=
{},
)
)
def
benchmark
(
num_tokens
,
num_experts
,
topk
,
provider
):
def
benchmark
(
num_tokens
,
num_experts
,
topk
,
ep_size
,
provider
):
"""Benchmark function for Triton."""
block_size
=
256
torch
.
cuda
.
manual_seed_all
(
0
)
topk_ids
=
get_topk_ids
(
num_tokens
,
num_experts
,
topk
)
e_map
=
None
if
ep_size
!=
1
:
local_e
=
num_experts
//
ep_size
e_ids
=
torch
.
randperm
(
num_experts
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)[:
local_e
]
e_map
=
torch
.
full
((
num_experts
,),
-
1
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
e_map
[
e_ids
]
=
torch
.
arange
(
local_e
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
"vllm"
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
moe_align_block_size
(
topk_ids
,
block_size
,
num_experts
),
lambda
:
moe_align_block_size
(
topk_ids
,
block_size
,
num_experts
,
e_map
,
ignore_invalid_experts
=
True
),
quantiles
=
quantiles
,
)
...
...
csrc/cpu/cpu_attn_impl.hpp
View file @
8d75f22e
...
...
@@ -1246,14 +1246,8 @@ class AttentionMainLoop {
// rescale sum and partial outputs
if
(
need_rescale
)
{
// compute rescale factor
#ifdef DEFINE_FAST_EXP
vec_op
::
FP32Vec16
rescale_factor_vec
(
rescale_factor
);
rescale_factor_vec
=
fast_exp
(
rescale_factor_vec
);
rescale_factor
=
rescale_factor_vec
.
get_last_elem
();
#else
rescale_factor
=
std
::
exp
(
rescale_factor
);
vec_op
::
FP32Vec16
rescale_factor_vec
(
rescale_factor
);
#endif
// rescale sum
new_sum_val
+=
rescale_factor
*
init_sum_val
;
...
...
@@ -1889,15 +1883,8 @@ class AttentionMainLoop {
:
curr_output_buffer
;
float
rescale_factor
=
final_max
>
curr_max
?
curr_max
-
final_max
:
final_max
-
curr_max
;
#ifdef DEFINE_FAST_EXP
vec_op
::
FP32Vec16
rescale_factor_vec
(
rescale_factor
);
rescale_factor_vec
=
fast_exp
(
rescale_factor_vec
);
rescale_factor
=
rescale_factor_vec
.
get_last_elem
();
#else
rescale_factor
=
std
::
exp
(
rescale_factor
);
vec_op
::
FP32Vec16
rescale_factor_vec
(
rescale_factor
);
#endif
local_sum
[
head_idx
]
=
final_max
>
curr_max
?
final_sum
+
rescale_factor
*
curr_sum
...
...
csrc/cpu/cpu_attn_macros.h
View file @
8d75f22e
...
...
@@ -60,4 +60,54 @@
#endif
#ifdef __aarch64__
// Implementation copied from Arm Optimized Routines (expf AdvSIMD)
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
#include <limits>
// DEFINE_FAST_EXP expands, inside the body of the function that uses it, to a
// set of NEON constants plus two lambdas:
//
//   neon_expf(float32x4_t values) -> float32x4_t
//     4-lane single-precision exp approximation:
//       n      = vrndaq(values * inv_ln2)        (rounded multiple of ln2)
//       r      = values - n * ln2_hi - n * ln2_lo (two-step range reduction)
//       scale  = bits((int)n << 23) + 0x3f800000  (exponent-field bias added)
//       poly   = c4*r + (c3 + c2*r + (c1 + c0*r) * r^2) * r^2
//       result = scale + poly * scale
//     Lanes with values >= 0x1.5d5e2ap+6f are replaced by +infinity and lanes
//     with values <= -0x1.5d5e2ap+6f are replaced by 0.
//
//   fast_exp(vec_op::FP32Vec16& vec) -> vec_op::FP32Vec16
//     Applies neon_expf to each of the four float32x4_t registers stored in
//     vec.reg.val[0..3] and wraps the results in a new FP32Vec16.
//
// Do not add `//` comments inside the macro body: on a backslash-continued
// line the comment would absorb the following lines of the definition.
#define DEFINE_FAST_EXP \
const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f); \
const float ln2_hi = 0x1.62e4p-1f; \
const float ln2_lo = 0x1.7f7d1cp-20f; \
const float c0 = 0x1.0e4020p-7f; \
const float c2 = 0x1.555e66p-3f; \
const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2}; \
const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000); \
const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f); \
const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f); \
const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f); \
const float32x4_t pos_special_bound = vdupq_n_f32(0x1.5d5e2ap+6f); \
const float32x4_t neg_special_bound = vnegq_f32(pos_special_bound); \
const float32x4_t inf = \
vdupq_n_f32(std::numeric_limits<float>::infinity()); \
const float32x4_t zero = vdupq_n_f32(0.0f); \
auto neon_expf = [&](float32x4_t values) __attribute__((always_inline)) { \
float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2)); \
float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0); \
r = vfmsq_laneq_f32(r, n, ln2_c02, 1); \
uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23); \
float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias)); \
float32x4_t r2 = vmulq_f32(r, r); \
float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2); \
float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3); \
q = vfmaq_f32(q, p, r2); \
p = vmulq_f32(c4, r); \
float32x4_t poly = vfmaq_f32(p, q, r2); \
poly = vfmaq_f32(scale, poly, scale); \
const uint32x4_t hi_mask = vcgeq_f32(values, pos_special_bound); \
const uint32x4_t lo_mask = vcleq_f32(values, neg_special_bound); \
poly = vbslq_f32(hi_mask, inf, poly); \
return vbslq_f32(lo_mask, zero, poly); \
}; \
auto fast_exp = [&](vec_op::FP32Vec16& vec) \
__attribute__((always_inline)) { \
float32x4x4_t result; \
result.val[0] = neon_expf(vec.reg.val[0]); \
result.val[1] = neon_expf(vec.reg.val[1]); \
result.val[2] = neon_expf(vec.reg.val[2]); \
result.val[3] = neon_expf(vec.reg.val[3]); \
return vec_op::FP32Vec16(result); \
};
#endif // __aarch64__
#endif
\ No newline at end of file
csrc/dispatch_utils.h
View file @
8d75f22e
...
...
@@ -118,6 +118,24 @@
} \
}
// Lifts a runtime boolean into a compile-time constant.
//   expr       - runtime condition, evaluated once.
//   const_expr - identifier that is declared as a `constexpr bool` (true or
//                false) in the scope where the callable is invoked.
//   ...        - a callable (typically a lambda) that is invoked with no
//                arguments; it is emitted in both branches, so it is compiled
//                once per value of const_expr.
// The macro expands to a bare if/else statement (not wrapped in do/while).
#define VLLM_DISPATCH_BOOL(expr, const_expr, ...) \
if (expr) { \
constexpr bool const_expr = true; \
__VA_ARGS__(); \
} else { \
constexpr bool const_expr = false; \
__VA_ARGS__(); \
}
// Lifts a runtime quantization group size into a compile-time constant.
//   group_size       - runtime value; only 128 and 64 are supported.
//   const_group_size - identifier declared as a `constexpr int` holding the
//                      matched value in the scope where the callable runs.
//   ...              - a callable (typically a lambda) invoked with no
//                      arguments; it is compiled once per supported size.
// Any other group size raises via TORCH_CHECK instead of silently skipping the
// callable, which would leave the caller's outputs unwritten.
#define VLLM_DISPATCH_GROUP_SIZE(group_size, const_group_size, ...)  \
  if ((group_size) == 128) {                                         \
    constexpr int const_group_size = 128;                            \
    __VA_ARGS__();                                                   \
  } else if ((group_size) == 64) {                                   \
    constexpr int const_group_size = 64;                             \
    __VA_ARGS__();                                                   \
  } else {                                                           \
    TORCH_CHECK(false, "Unsupported group_size: ", (group_size),     \
                ". Supported values are 64 and 128.");               \
  }
#define VLLM_DISPATCH_RANK234(NUM_DIMS, ...) \
switch (NUM_DIMS) { \
case 2: { \
...
...
csrc/moe/grouped_topk_kernels.cu
View file @
8d75f22e
...
...
@@ -444,23 +444,27 @@ __device__ inline T apply_sigmoid(T val) {
return
cuda_cast
<
T
,
float
>
(
sigmoid_accurate
(
f
));
}
template
<
typename
T
>
// Applies the scoring function selected at compile time by SF to val.
// SCORING_SIGMOID routes through apply_sigmoid; every other ScoringFunc value
// returns val unchanged.
template <ScoringFunc SF, typename T>
__device__ inline T apply_scoring(T val) {
  if constexpr (SF != SCORING_SIGMOID) {
    return val;
  } else {
    return apply_sigmoid(val);
  }
}
template
<
typename
T
,
ScoringFunc
SF
>
__device__
void
topk_with_k2
(
T
*
output
,
T
const
*
input
,
T
const
*
bias
,
cg
::
thread_block_tile
<
32
>
const
&
tile
,
int32_t
const
lane_id
,
int
const
num_experts_per_group
,
int
const
scoring_func
)
{
int
const
num_experts_per_group
)
{
// Get the top2 per thread
T
largest
=
neg_inf
<
T
>
();
T
second_largest
=
neg_inf
<
T
>
();
if
(
num_experts_per_group
>
WARP_SIZE
)
{
for
(
int
i
=
lane_id
;
i
<
num_experts_per_group
;
i
+=
WARP_SIZE
)
{
T
value
=
input
[
i
];
// Apply scoring function if needed
if
(
scoring_func
==
SCORING_SIGMOID
)
{
value
=
apply_sigmoid
(
value
);
}
T
value
=
apply_scoring
<
SF
>
(
input
[
i
]);
value
=
value
+
bias
[
i
];
if
(
value
>
largest
)
{
...
...
@@ -472,11 +476,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
}
}
else
{
for
(
int
i
=
lane_id
;
i
<
num_experts_per_group
;
i
+=
WARP_SIZE
)
{
T
value
=
input
[
i
];
// Apply scoring function if needed
if
(
scoring_func
==
SCORING_SIGMOID
)
{
value
=
apply_sigmoid
(
value
);
}
T
value
=
apply_scoring
<
SF
>
(
input
[
i
]);
value
=
value
+
bias
[
i
];
largest
=
value
;
}
...
...
@@ -501,13 +501,12 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
}
}
template
<
typename
T
>
template
<
typename
T
,
ScoringFunc
SF
>
__global__
void
topk_with_k2_kernel
(
T
*
output
,
T
*
input
,
T
const
*
bias
,
int64_t
const
num_tokens
,
int64_t
const
num_cases
,
int64_t
const
n_group
,
int64_t
const
num_experts_per_group
,
int
const
scoring_func
)
{
int64_t
const
num_experts_per_group
)
{
int32_t
warp_id
=
threadIdx
.
x
/
WARP_SIZE
;
int32_t
lane_id
=
threadIdx
.
x
%
WARP_SIZE
;
...
...
@@ -525,21 +524,21 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm
volatile
(
"griddepcontrol.wait;"
);
#endif
topk_with_k2
(
output
,
input
,
group_bias
,
tile
,
lane_id
,
num_experts_per_group
,
scoring_func
);
topk_with_k2
<
T
,
SF
>
(
output
,
input
,
group_bias
,
tile
,
lane_id
,
num_experts_per_group
);
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm
volatile
(
"griddepcontrol.launch_dependents;"
);
#endif
}
template
<
typename
T
,
typename
IdxT
>
template
<
typename
T
,
typename
IdxT
,
ScoringFunc
SF
,
int
NGroup
=
-
1
>
__global__
void
group_idx_and_topk_idx_kernel
(
T
*
scores
,
T
const
*
group_scores
,
float
*
topk_values
,
IdxT
*
topk_indices
,
T
const
*
bias
,
int64_t
const
num_tokens
,
int64_t
const
n_group
,
int64_t
const
topk_group
,
int64_t
const
topk
,
int64_t
const
num_experts
,
int64_t
const
num_experts_per_group
,
bool
renormalize
,
double
routed_scaling_factor
,
int
scoring_func
)
{
double
routed_scaling_factor
)
{
int32_t
warp_id
=
threadIdx
.
x
/
WARP_SIZE
;
int32_t
lane_id
=
threadIdx
.
x
%
WARP_SIZE
;
int32_t
case_id
=
...
...
@@ -549,6 +548,11 @@ __global__ void group_idx_and_topk_idx_kernel(
topk_values
+=
case_id
*
topk
;
topk_indices
+=
case_id
*
topk
;
constexpr
bool
kUseStaticNGroup
=
(
NGroup
>
0
);
// use int32 to avoid implicit conversion
int32_t
const
n_group_i32
=
kUseStaticNGroup
?
NGroup
:
static_cast
<
int32_t
>
(
n_group
);
int32_t
align_num_experts_per_group
=
warp_topk
::
round_up_to_multiple_of
<
WARP_SIZE
>
(
num_experts_per_group
);
...
...
@@ -574,13 +578,14 @@ __global__ void group_idx_and_topk_idx_kernel(
if
(
case_id
<
num_tokens
)
{
// calculate group_idx
int32_t
target_num_min
=
WARP_SIZE
-
n_group
+
topk_group
;
int32_t
target_num_min
=
WARP_SIZE
-
n_group_i32
+
static_cast
<
int32_t
>
(
topk_group
);
// The check is necessary to avoid abnormal input
if
(
lane_id
<
n_group
&&
is_finite
(
group_scores
[
lane_id
]))
{
if
(
lane_id
<
n_group
_i32
&&
is_finite
(
group_scores
[
lane_id
]))
{
value
=
group_scores
[
lane_id
];
}
int
count_equal_to_top_value
=
WARP_SIZE
-
n_group
;
int
count_equal_to_top_value
=
WARP_SIZE
-
n_group
_i32
;
int
pre_count_equal_to_top_value
=
0
;
// Use loop to find the largest top_group
while
(
count_equal_to_top_value
<
target_num_min
)
{
...
...
@@ -604,7 +609,7 @@ __global__ void group_idx_and_topk_idx_kernel(
int
count_equalto_topkth_group
=
0
;
bool
if_proceed_next_topk
=
topk_group_value
!=
neg_inf
<
T
>
();
if
(
case_id
<
num_tokens
&&
if_proceed_next_topk
)
{
for
(
int
i_group
=
0
;
i_group
<
n_group
;
i_group
++
)
{
auto
process_group
=
[
&
](
int
i_group
)
{
if
((
group_scores
[
i_group
]
>
topk_group_value
)
||
((
group_scores
[
i_group
]
==
topk_group_value
)
&&
(
count_equalto_topkth_group
<
num_equalto_topkth_group
)))
{
...
...
@@ -613,11 +618,10 @@ __global__ void group_idx_and_topk_idx_kernel(
i
+=
WARP_SIZE
)
{
T
candidates
=
neg_inf
<
T
>
();
if
(
i
<
num_experts_per_group
)
{
//
A
pply scoring function (if any) and add bias
//
a
pply scoring function (if any) and add bias
T
input
=
scores
[
offset
+
i
];
if
(
is_finite
(
input
))
{
T
score
=
(
scoring_func
==
SCORING_SIGMOID
)
?
apply_sigmoid
(
input
)
:
input
;
T
score
=
apply_scoring
<
SF
>
(
input
);
candidates
=
score
+
bias
[
offset
+
i
];
}
}
...
...
@@ -627,6 +631,17 @@ __global__ void group_idx_and_topk_idx_kernel(
count_equalto_topkth_group
++
;
}
}
};
if
constexpr
(
kUseStaticNGroup
)
{
#pragma unroll
for
(
int
i_group
=
0
;
i_group
<
NGroup
;
++
i_group
)
{
process_group
(
i_group
);
}
}
else
{
for
(
int
i_group
=
0
;
i_group
<
n_group_i32
;
++
i_group
)
{
process_group
(
i_group
);
}
}
queue
.
done
();
__syncwarp
();
...
...
@@ -646,27 +661,24 @@ __global__ void group_idx_and_topk_idx_kernel(
if
(
i
<
topk
)
{
// Load the score value (without bias) for normalization
T
input
=
scores
[
s_topk_idx
[
i
]];
value
=
(
scoring_func
==
SCORING_SIGMOID
)
?
apply_sigmoid
(
input
)
:
input
;
value
=
apply_scoring
<
SF
>
(
input
);
s_topk_value
[
i
]
=
value
;
}
if
(
renormalize
)
{
topk_sum
+=
cg
::
reduce
(
tile
,
cuda_cast
<
float
,
T
>
(
value
),
cg
::
plus
<
float
>
());
}
}
}
__syncthreads
();
if
(
case_id
<
num_tokens
)
{
if
(
if_proceed_next_topk
)
{
for
(
int
i
=
lane_id
;
i
<
topk
;
i
+=
WARP_SIZE
)
{
float
value
;
if
(
renormalize
)
{
value
=
cuda_cast
<
float
,
T
>
(
s_topk_value
[
i
])
/
topk_sum
*
routed_scaling_factor
;
}
else
{
value
=
cuda_cast
<
float
,
T
>
(
s_topk_value
[
i
])
*
routed_scaling_factor
;
}
float
base
=
cuda_cast
<
float
,
T
>
(
s_topk_value
[
i
]);
float
value
=
renormalize
?
(
base
/
topk_sum
*
routed_scaling_factor
)
:
(
base
*
routed_scaling_factor
);
topk_indices
[
i
]
=
s_topk_idx
[
i
];
topk_values
[
i
]
=
value
;
}
...
...
@@ -684,6 +696,45 @@ __global__ void group_idx_and_topk_idx_kernel(
#endif
}
// Host-side launcher for group_idx_and_topk_idx_kernel<T, IdxT, SF, NGroup>.
// When n_group is 4, 8, 16 or 32 the instantiation with the matching
// compile-time NGroup is launched; any other value launches the instantiation
// that uses the default NGroup template argument. All remaining arguments are
// forwarded unchanged to cudaLaunchKernelEx together with `config`.
template <typename T, typename IdxT, ScoringFunc SF>
inline void launch_group_idx_and_topk_kernel(
    cudaLaunchConfig_t const& config, T* scores, T* group_scores,
    float* topk_values, IdxT* topk_indices, T const* bias,
    int64_t const num_tokens, int64_t const n_group, int64_t const topk_group,
    int64_t const topk, int64_t const num_experts,
    int64_t const num_experts_per_group, bool const renormalize,
    double const routed_scaling_factor) {
  // Single place that knows the kernel's argument list; only the kernel
  // instantiation varies between the branches below.
  auto submit = [&](auto* kernel) {
    cudaLaunchKernelEx(&config, kernel, scores, group_scores, topk_values,
                       topk_indices, bias, num_tokens, n_group, topk_group,
                       topk, num_experts, num_experts_per_group, renormalize,
                       routed_scaling_factor);
  };

  if (n_group == 4) {
    submit(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 4>);
  } else if (n_group == 8) {
    submit(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 8>);
  } else if (n_group == 16) {
    submit(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 16>);
  } else if (n_group == 32) {
    submit(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 32>);
  } else {
    // Generic path: NGroup keeps its default template argument.
    submit(&group_idx_and_topk_idx_kernel<T, IdxT, SF>);
  }
}
template
<
typename
T
,
typename
IdxT
>
void
invokeNoAuxTc
(
T
*
scores
,
T
*
group_scores
,
float
*
topk_values
,
IdxT
*
topk_indices
,
T
const
*
bias
,
int64_t
const
num_tokens
,
...
...
@@ -694,7 +745,6 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
cudaStream_t
const
stream
=
0
)
{
int64_t
num_cases
=
num_tokens
*
n_group
;
int64_t
topk_with_k2_num_blocks
=
(
num_cases
-
1
)
/
NUM_WARPS_PER_BLOCK
+
1
;
auto
*
kernel_instance1
=
&
topk_with_k2_kernel
<
T
>
;
cudaLaunchConfig_t
config
;
config
.
gridDim
=
topk_with_k2_num_blocks
;
config
.
blockDim
=
BLOCK_SIZE
;
...
...
@@ -705,16 +755,33 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
attrs
[
0
].
val
.
programmaticStreamSerializationAllowed
=
enable_pdl
;
config
.
numAttrs
=
1
;
config
.
attrs
=
attrs
;
auto
const
sf
=
static_cast
<
ScoringFunc
>
(
scoring_func
);
int64_t
const
num_experts_per_group
=
num_experts
/
n_group
;
auto
launch_topk_with_k2
=
[
&
](
auto
*
kernel_instance1
)
{
cudaLaunchKernelEx
(
&
config
,
kernel_instance1
,
group_scores
,
scores
,
bias
,
num_tokens
,
num_cases
,
n_group
,
num_experts
/
n_group
,
scoring_func
);
num_tokens
,
num_cases
,
n_group
,
num_experts_per_group
);
};
switch
(
sf
)
{
case
SCORING_NONE
:
{
auto
*
kernel_instance1
=
&
topk_with_k2_kernel
<
T
,
SCORING_NONE
>
;
launch_topk_with_k2
(
kernel_instance1
);
break
;
}
case
SCORING_SIGMOID
:
{
auto
*
kernel_instance1
=
&
topk_with_k2_kernel
<
T
,
SCORING_SIGMOID
>
;
launch_topk_with_k2
(
kernel_instance1
);
break
;
}
default:
// should be guarded by higher level checks.
TORCH_CHECK
(
false
,
"Unsupported scoring_func in invokeNoAuxTc"
);
}
int64_t
topk_with_k_group_num_blocks
=
(
num_tokens
-
1
)
/
NUM_WARPS_PER_BLOCK
+
1
;
size_t
dynamic_smem_in_bytes
=
warp_topk
::
calc_smem_size_for_block_wide
<
T
,
int32_t
>
(
NUM_WARPS_PER_BLOCK
,
topk
);
auto
*
kernel_instance2
=
&
group_idx_and_topk_idx_kernel
<
T
,
IdxT
>
;
config
.
gridDim
=
topk_with_k_group_num_blocks
;
config
.
blockDim
=
BLOCK_SIZE
;
config
.
dynamicSmemBytes
=
dynamic_smem_in_bytes
;
...
...
@@ -723,10 +790,24 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
attrs
[
0
].
val
.
programmaticStreamSerializationAllowed
=
enable_pdl
;
config
.
numAttrs
=
1
;
config
.
attrs
=
attrs
;
cudaLaunchKernelEx
(
&
config
,
kernel_instance2
,
scores
,
group_scores
,
topk_values
,
topk_indices
,
bias
,
num_tokens
,
n_group
,
topk_group
,
topk
,
num_experts
,
num_experts
/
n_group
,
renormalize
,
routed_scaling_factor
,
scoring_func
);
switch
(
sf
)
{
case
SCORING_NONE
:
{
launch_group_idx_and_topk_kernel
<
T
,
IdxT
,
SCORING_NONE
>
(
config
,
scores
,
group_scores
,
topk_values
,
topk_indices
,
bias
,
num_tokens
,
n_group
,
topk_group
,
topk
,
num_experts
,
num_experts_per_group
,
renormalize
,
routed_scaling_factor
);
break
;
}
case
SCORING_SIGMOID
:
{
launch_group_idx_and_topk_kernel
<
T
,
IdxT
,
SCORING_SIGMOID
>
(
config
,
scores
,
group_scores
,
topk_values
,
topk_indices
,
bias
,
num_tokens
,
n_group
,
topk_group
,
topk
,
num_experts
,
num_experts_per_group
,
renormalize
,
routed_scaling_factor
);
break
;
}
default:
TORCH_CHECK
(
false
,
"Unsupported scoring_func in invokeNoAuxTc"
);
}
}
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
...
...
csrc/moe/moe_align_sum_kernels.cu
View file @
8d75f22e
This diff is collapsed.
Click to expand it.
csrc/moe/moe_lora_align_sum_kernels.cu
deleted
100644 → 0
View file @
ce888aa4
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>
#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#include "core/math.hpp"
namespace {
// Flattens a (row, col) coordinate into a linear offset for a row-major
// layout that has total_col entries per row.
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
                                         int32_t col) {
  return col + total_col * row;
}
}  // namespace
// TODO: Refactor common parts with moe_align_sum_kernels
// Per-LoRA variant of the MoE "align block size" kernel.
//
// Each thread block handles one entry of lora_ids (selected by blockIdx.x).
// For that LoRA id it counts, per expert, the tokens whose
// token_lora_mapping entry equals the LoRA id, pads every expert's count up to
// a multiple of block_size, and writes:
//   - sorted_token_ids[lora_id, :] : flat topk_ids indices grouped by expert
//                                    (unused slots keep the value numel),
//   - expert_ids[lora_id, :]       : expert owning each block_size-wide block
//                                    (unused slots keep -1),
//   - total_tokens_post_pad[lora_id]: padded token total over all experts.
//
// Dynamic shared memory layout (int32 units): cumsum[num_experts + 1] followed
// by tokens_cnts, addressed through index() as rows of num_experts entries;
// rows 0..blockDim.x are touched below.
//
// NOTE(review): topk_ids[i] is used directly as a column index / cumsum index
// with no range check — presumably every id lies in [0, num_experts); confirm
// against the caller. max_loras is not read inside the kernel.
template <typename scalar_t, typename token_cnts_t>
__global__ void moe_lora_align_sum_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping,
    int64_t block_size, int num_experts, int max_loras, size_t numel,
    int max_num_tokens_padded, int max_num_m_blocks,
    int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
    int topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
    int32_t* lora_ids) {
  // Each thread owns a contiguous shard of tokens_per_thread flat indices.
  const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  int lora_idx = blockIdx.x;
  int lora_id = lora_ids[lora_idx];
  // The whole block exits together (lora_id is the same for every thread), so
  // no thread is left waiting at the __syncthreads() calls below.
  if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
    return;
  }
  extern __shared__ int32_t shared_mem[];
  int32_t* cumsum = shared_mem;
  token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);

  // Initialize sorted_token_ids with numel
  for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
    sorted_token_ids[lora_id * max_num_tokens_padded + it] = numel;
  }

  // Initialize expert_ids with -1
  for (size_t it = threadIdx.x; it < max_num_m_blocks; it += blockDim.x) {
    expert_ids[lora_id * max_num_m_blocks + it] = -1;
  }

  // Initialize total_tokens_post_pad with 0
  if (threadIdx.x == 0) {
    total_tokens_post_pad[lora_id] = 0;
  }

  // Zero this thread's counter row (row threadIdx.x + 1; row 0 is reserved as
  // the zero base of the prefix sum below).
  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
  }

  // Count, per expert, the tokens in this thread's shard that belong to the
  // current LoRA; mask is 1 for a match and 0 otherwise.
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int mask = token_lora_mapping[i / topk_num] == lora_id;
    int idx = index(num_experts, threadIdx.x + 1, topk_ids[i]);
    tokens_cnts[idx] += mask;
  }

  __syncthreads();

  // For each expert we accumulate the token counts from the different threads.
  // After this running sum, row t holds the number of matching tokens seen by
  // threads 0..t-1, and row blockDim.x holds the expert's total.
  if (threadIdx.x < num_experts) {
    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
    for (int i = 1; i <= blockDim.x; ++i) {
      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
    }
  }

  __syncthreads();

  // We accumulate the token counts of all experts in thread 0.
  // Each expert's total is rounded up to a multiple of block_size first.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] +
                  div_ceil(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
                           block_size) *
                      block_size;
    }
    total_tokens_post_pad[lora_id] = static_cast<int32_t>(cumsum[num_experts]);
  }

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding
   * blocks and stores the corresponding expert_id for each block.
   */
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[index(max_num_m_blocks, lora_id, i / block_size)] =
          threadIdx.x;
    }
  }

  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    /** The cumsum[expert_id] stores the starting index of the tokens that the
     * expert with expert_id needs to process, and
     * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
     * processed by the expert with expert_id within the current thread's token
     * shard.
     */
    int32_t rank_post_pad =
        tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
        cumsum[expert_id];

    int mask = (int)token_lora_mapping[i / topk_num] == lora_id;
    // The slot was initialised to numel, so adding (i - numel) leaves i in it
    // (the subtraction is done in size_t and relies on wraparound when
    // narrowed). With mask == 0 the addend is 0 and the slot is untouched.
    atomicAdd(
        &sorted_token_ids[index(max_num_tokens_padded, lora_id, rank_post_pad)],
        (i - numel) * mask);
    // Advance this thread's write cursor for the expert only on a match.
    tokens_cnts[index(num_experts, threadIdx.x, expert_id)] += mask;
  }
}
// Host entry point that launches moe_lora_align_sum_kernel.
//
// Launch shape: max_loras thread blocks of max(num_experts, 128) threads each,
// with dynamic shared memory sized for (num_thread + 1) * num_experts counters
// plus num_experts + 1 cumulative sums (all int32).
//
// Raises (TORCH_CHECK) when block_size <= 0, when more than 1024 threads would
// be needed, or when the shared-memory requirement exceeds the device's opt-in
// per-block limit; CUDA runtime failures surface through AT_CUDA_CHECK.
//
// token_lora_mapping, sorted_token_ids, expert_ids, num_tokens_post_pad,
// adapter_enabled and lora_ids are accessed as int32 tensors; topk_ids may be
// any integral type handled by VLLM_DISPATCH_INTEGRAL_TYPES.
void moe_lora_align_block_size(
    torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
    int64_t num_experts, int64_t block_size, int64_t max_loras,
    int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
    torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
    torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
    torch::Tensor lora_ids) {
  const int topk_num = topk_ids.size(1);

  TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");

  // Make the tensor's device current so that the stream queried below and the
  // kernel launch both target the device that actually holds the data.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(topk_ids));

  int device_max_shared_mem = 0;
  auto dev = topk_ids.get_device();
  // Surface a failed attribute query instead of comparing against garbage.
  AT_CUDA_CHECK(cudaDeviceGetAttribute(
      &device_max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int32_t num_thread = max((int32_t)num_experts, 128);  // WARP_SIZE,
  TORCH_CHECK(num_thread <= 1024,
              "num_thread must be less than 1024, "
              "and fallback is not implemented yet.");
  // tokens_cnts rows (one per thread plus the zero base row) + cumsum array.
  const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) +
                             (num_experts + 1) * sizeof(int32_t);

  if (shared_mem > device_max_shared_mem) {
    TORCH_CHECK(false,
                "Shared memory usage exceeds device limit, and global memory "
                "fallback is not implemented yet.");
  }

  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] {
        dim3 blockDim(num_thread);
        auto kernel = moe_lora_align_sum_kernel<scalar_t, int32_t>;
        // Opt in to the larger dynamic shared-memory carve-out before launch.
        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
            (void*)kernel, shared_mem));
        kernel<<<max_loras, blockDim, shared_mem, stream>>>(
            topk_ids.data_ptr<scalar_t>(),
            token_lora_mapping.data_ptr<int32_t>(), block_size, num_experts,
            max_loras, topk_ids.numel(), max_num_tokens_padded,
            max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
            expert_ids.data_ptr<int32_t>(), topk_num,
            num_tokens_post_pad.data_ptr<int32_t>(),
            adapter_enabled.data_ptr<int32_t>(), lora_ids.data_ptr<int32_t>());
      });
}
\ No newline at end of file
csrc/moe/moe_ops.h
View file @
8d75f22e
...
...
@@ -11,7 +11,8 @@ void moe_sum(torch::Tensor& input, torch::Tensor& output);
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int64_t
num_experts
,
int64_t
block_size
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
);
torch
::
Tensor
num_tokens_post_pad
,
std
::
optional
<
torch
::
Tensor
>
maybe_expert_map
);
void
batched_moe_align_block_size
(
int64_t
max_tokens_per_batch
,
int64_t
block_size
,
...
...
@@ -26,7 +27,7 @@ void moe_lora_align_block_size(
int64_t
max_num_tokens_padded
,
int64_t
max_num_m_blocks
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
expert_ids
,
torch
::
Tensor
num_tokens_post_pad
,
torch
::
Tensor
adapter_enabled
,
torch
::
Tensor
lora_ids
);
torch
::
Tensor
lora_ids
,
std
::
optional
<
torch
::
Tensor
>
maybe_expert_map
);
#ifndef USE_ROCM
torch
::
Tensor
moe_wna16_gemm
(
torch
::
Tensor
input
,
torch
::
Tensor
output
,
torch
::
Tensor
b_qweight
,
torch
::
Tensor
b_scales
,
...
...
csrc/moe/torch_bindings.cpp
View file @
8d75f22e
...
...
@@ -19,7 +19,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"moe_align_block_size(Tensor topk_ids, int num_experts,"
" int block_size, Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad) -> ()"
);
" Tensor! num_tokens_post_pad,"
" Tensor? maybe_expert_map) -> ()"
);
m
.
impl
(
"moe_align_block_size"
,
torch
::
kCUDA
,
&
moe_align_block_size
);
// Aligning the number of tokens to be processed by each expert such
...
...
@@ -46,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor !experts_ids,"
" Tensor !num_tokens_post_pad,"
" Tensor !adapter_enabled,"
" Tensor !lora_ids) -> () "
);
" Tensor !lora_ids,"
" Tensor? maybe_expert_map) -> () "
);
m
.
impl
(
"moe_lora_align_block_size"
,
torch
::
kCUDA
,
&
moe_lora_align_block_size
);
#ifndef USE_ROCM
...
...
csrc/ops.h
View file @
8d75f22e
...
...
@@ -102,13 +102,16 @@ void apply_repetition_penalties_(torch::Tensor& logits,
const
torch
::
Tensor
&
output_mask
,
const
torch
::
Tensor
&
repetition_penalties
);
void
top_k_per_row
(
const
torch
::
Tensor
&
logits
,
const
torch
::
Tensor
&
rowStarts
,
void
top_k_per_row_prefill
(
const
torch
::
Tensor
&
logits
,
const
torch
::
Tensor
&
rowStarts
,
const
torch
::
Tensor
&
rowEnds
,
torch
::
Tensor
&
indices
,
int64_t
numRows
,
int64_t
stride0
,
int64_t
stride1
);
int64_t
numRows
,
int64_t
stride0
,
int64_t
stride1
,
int64_t
topK
);
void
top_k_per_row_decode
(
const
torch
::
Tensor
&
logits
,
int64_t
next_n
,
const
torch
::
Tensor
&
seq_lens
,
torch
::
Tensor
&
indices
,
int64_t
numRows
,
int64_t
stride0
,
int64_t
stride1
);
const
torch
::
Tensor
&
seqLens
,
torch
::
Tensor
&
indices
,
int64_t
numRows
,
int64_t
stride0
,
int64_t
stride1
,
int64_t
topK
);
// void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& weight, torch::Tensor& scale,
...
...
@@ -128,6 +131,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
std
::
optional
<
torch
::
Tensor
>
scale_ub
,
std
::
optional
<
torch
::
Tensor
>
residual
);
void
rms_norm_per_block_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
weight
,
torch
::
Tensor
&
scales
,
double
const
epsilon
,
std
::
optional
<
torch
::
Tensor
>
scale_ub
,
std
::
optional
<
torch
::
Tensor
>
residual
,
int64_t
group_size
,
bool
is_scale_transposed
);
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
std
::
optional
<
torch
::
Tensor
>
key
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
...
...
@@ -254,7 +264,8 @@ void get_cutlass_moe_mm_data(
void
get_cutlass_moe_mm_problem_sizes
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
,
std
::
optional
<
bool
>
force_swap_ab
=
std
::
nullopt
);
void
get_cutlass_pplx_moe_mm_data
(
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
...
...
@@ -301,6 +312,14 @@ void per_token_group_quant_int8(const torch::Tensor& input,
torch
::
Tensor
&
output_q
,
torch
::
Tensor
&
output_s
,
int64_t
group_size
,
double
eps
,
double
int8_min
,
double
int8_max
);
// Fused activation quantisation + DeepGEMM-compatible UE8M0-packed scales.
void
per_token_group_quant_8bit_packed
(
const
torch
::
Tensor
&
input
,
torch
::
Tensor
&
output_q
,
torch
::
Tensor
&
output_s_packed
,
int64_t
group_size
,
double
eps
,
double
min_8bit
,
double
max_8bit
);
#endif
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
...
...
csrc/quantization/cutlass_w4a8/get_group_starts.cuh
0 → 100644
View file @
8d75f22e
// see csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh
#pragma once
#include <cuda.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>
#include "core/scalar_type.hpp"
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"
// ElementB is int32 (packed int4)
// ElementGroupScale is cutlass::Array<cutlass::float_e4m3_t, 8> (packed fp8)
// Fills the per-expert pointer tables consumed by the grouped GEMM.
// Thread threadIdx.x handles expert threadIdx.x: it reads that expert's entry
// in expert_offsets and stores, for every operand, the base pointer advanced
// to the expert's slice.
//   - A / output / A-scales are offset by the expert's entry in expert_offsets
//     (times k, times n, and times 1 respectively).
//   - B-scales, packed B and B group scales are offset by the expert index
//     (times n, times k * n / 8, and times scale_k * n respectively).
template <typename ElementA, typename ElementB, typename ElementC,
          typename ElementAccumulator, typename ElementGroupScale>
__global__ void get_group_gemm_starts(
    int64_t* expert_offsets, ElementA** a_offsets, ElementB** b_offsets,
    ElementC** out_offsets, ElementAccumulator** a_scales_offsets,
    ElementAccumulator** b_scales_offsets,
    ElementGroupScale** b_group_scales_offsets, ElementA* a_base_as_int,
    ElementB* b_base_as_int, ElementC* out_base_as_int,
    ElementAccumulator* a_scales_base_as_int,
    ElementAccumulator* b_scales_base_as_int,
    ElementGroupScale* b_group_scales_base_as_int, int64_t n, int64_t k,
    int64_t scale_k) {
  // Eight int4 values are packed into each int32 element of B.
  constexpr int kInt4PerInt32 = 8;

  int e = threadIdx.x;
  int64_t first_row = expert_offsets[e];

  // Operands addressed by the expert's offset entry (same scheme as w8a8).
  a_offsets[e] = a_base_as_int + first_row * k;
  out_offsets[e] = out_base_as_int + first_row * n;
  a_scales_offsets[e] = a_scales_base_as_int + first_row;

  // Operands addressed by expert index.
  b_scales_offsets[e] = b_scales_base_as_int + (n * e);

  // w4a8-specific: packed weights and their group scales.
  b_offsets[e] = b_base_as_int + (e * k * n / kInt4PerInt32);
  b_group_scales_offsets[e] = b_group_scales_base_as_int + (e * scale_k * n);
}
// Expands to one `else if` branch of an output-dtype dispatch chain, so it
// must directly follow an `if`/`else if` block.
//
// When out_tensors.dtype() equals TENSOR_C_TYPE, the branch launches
// get_group_gemm_starts (with C_TYPE as the output element type) using a
// single block of `num_experts` threads on `stream`.
//
// The expansion refers to these names from the enclosing scope:
// expert_offsets, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs,
// b_group_scales_ptrs, a_tensors, b_tensors, out_tensors, a_scales, b_scales,
// b_group_scales, num_experts, stream, n, k and scale_k.
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
get_group_gemm_starts<cutlass::float_e4m3_t, int32_t, C_TYPE, float, \
cutlass::Array<cutlass::float_e4m3_t, 8>> \
<<<1, num_experts, 0, stream>>>( \
static_cast<int64_t*>(expert_offsets.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
static_cast<int32_t**>(b_ptrs.data_ptr()), \
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
static_cast<float**>(a_scales_ptrs.data_ptr()), \
static_cast<float**>(b_scales_ptrs.data_ptr()), \
static_cast<cutlass::Array<cutlass::float_e4m3_t, 8>**>( \
b_group_scales_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
static_cast<int32_t*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<float*>(a_scales.data_ptr()), \
static_cast<float*>(b_scales.data_ptr()), \
static_cast<cutlass::Array<cutlass::float_e4m3_t, 8>*>( \
b_group_scales.data_ptr()), \
n, k, scale_k); \
}
namespace {

// Host-side launcher for get_group_gemm_starts.
//
// Validates the tensor dtypes, derives n / k / scale_k from the tensor
// shapes and launches the kernel on the current CUDA stream of a_tensors'
// device, with one thread per expert in a single block.
//
//   expert_offsets       int64 tensor; size(0) is taken as the expert count
//   *_ptrs               tensors that receive the per-expert pointer tables
//   a_tensors            fp8 e4m3 activations; size(1) is used as k
//   b_tensors            int32 weights (int4 values packed 8 per int32)
//   out_tensors          bf16 output; size(1) is used as n
//   a_scales, b_scales   float32 scales
//   b_group_scales       fp8 e4m3 group scales
//   b_group_size         group size along k; must be positive
void run_get_group_gemm_starts(
    torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
    torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
    torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
    torch::Tensor& b_group_scales_ptrs, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor& out_tensors,
    torch::Tensor const& a_scales, torch::Tensor const& b_scales,
    torch::Tensor const& b_group_scales, const int64_t b_group_size) {
  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
              "a_tensors must be float8_e4m3fn");
  // int4 8x packed into int32
  TORCH_CHECK(b_tensors.dtype() == torch::kInt32,
              "b_tensors must be int32 (packed int4)");
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32, "a_scales must be float32");
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32, "b_scales must be float32");
  // the underlying torch type is e4m3
  TORCH_CHECK(b_group_scales.dtype() == torch::kFloat8_e4m3fn,
              "b_group_scales must be float8_e4m3fn");
  // only support bf16 for now
  TORCH_CHECK(out_tensors.dtype() == torch::kBFloat16,
              "out_tensors must be bfloat16");
  // expect int64_t to avoid overflow during offset calculations
  TORCH_CHECK(expert_offsets.dtype() == torch::kInt64,
              "expert_offsets must be int64");

  // scale_k is computed with a division by b_group_size; reject
  // non-positive values instead of dividing by zero.
  TORCH_CHECK(b_group_size > 0, "b_group_size must be positive, got ",
              b_group_size);

  int num_experts = static_cast<int>(expert_offsets.size(0));

  // The kernel is launched as a single block with one thread per expert.
  // A block cannot hold more than 1024 threads, and the launch below is not
  // followed by an error check, so fail loudly here rather than leave the
  // pointer tables unwritten.
  constexpr int kMaxThreadsPerBlock = 1024;
  TORCH_CHECK(num_experts <= kMaxThreadsPerBlock,
              "num_experts (", num_experts,
              ") exceeds the single-block launch limit of ",
              kMaxThreadsPerBlock);

  // logical k, n
  int64_t n = out_tensors.size(1);
  int64_t k = a_tensors.size(1);
  int64_t scale_k = cutlass::ceil_div(k, b_group_size);

  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

  // Dispatch on the output dtype. The bf16 check above currently restricts
  // this to the first branch; the chain is kept so that lifting that check
  // is enough to enable fp16.
  if (false) {
  }
  __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
  __CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
  else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
}

}  // namespace
\ No newline at end of file
csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu
0 → 100644
View file @
8d75f22e
This diff is collapsed.
Click to expand it.
csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
View file @
8d75f22e
...
...
@@ -7,6 +7,7 @@
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/torch_utils.hpp"
#include "w4a8_utils.cuh"
#include "core/registration.h"
...
...
@@ -395,71 +396,6 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
return
packed_scales
;
}
/*
  GPU-accelerated implementation of cutlass::unified_encode_int4b.
  Constructs a lookup table in constant memory to map 8 bits
  (two 4-bit values) at a time. Assumes memory is contiguous
  and pointers are 16-byte aligned.
*/
__constant__ uint8_t kNibbleLUT[256];

// Remaps every byte of `in` through kNibbleLUT and stores it to `out`.
//
//   in      source bytes (two 4-bit values per byte)
//   out     destination bytes, same length as `in`
//   nbytes  number of bytes to process
//
// The bulk of the buffer is handled in 16-byte (uint4) chunks; any
// remaining nbytes % 16 bytes are handled one byte at a time so that
// buffers whose size is not a multiple of 16 are fully encoded.
__global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out,
                                            size_t nbytes) {
  constexpr size_t V = sizeof(uint4);  // 16 bytes
  // Widen before multiplying so the global index cannot wrap in 32 bits.
  const size_t tid = size_t(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t nthreads = size_t(gridDim.x) * blockDim.x;
  const size_t nvec = nbytes / V;

  // 1-D grid-stride loop over 16-byte chunks
  for (size_t vec = tid; vec < nvec; vec += nthreads) {
    uint4 v = reinterpret_cast<const uint4*>(in)[vec];
    uint8_t* b = reinterpret_cast<uint8_t*>(&v);
#pragma unroll
    for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]];
    reinterpret_cast<uint4*>(out)[vec] = v;
  }

  // Tail: bytes past the last full 16-byte chunk (at most 15 of them).
  for (size_t i = nvec * V + tid; i < nbytes; i += nthreads) {
    out[i] = kNibbleLUT[in[i]];
  }
}
static
bool
upload_lut
()
{
std
::
array
<
uint8_t
,
256
>
lut
{};
auto
map_nib
=
[](
uint8_t
v
)
->
uint8_t
{
// 1..7 -> (8 - v); keep 0 and 8..15
return
(
v
==
0
||
(
v
&
0x8
))
?
v
:
uint8_t
(
8
-
v
);
};
for
(
int
b
=
0
;
b
<
256
;
++
b
)
{
uint8_t
lo
=
b
&
0xF
;
uint8_t
hi
=
(
b
>>
4
)
&
0xF
;
lut
[
b
]
=
uint8_t
((
map_nib
(
hi
)
<<
4
)
|
map_nib
(
lo
));
}
cudaError_t
e
=
cudaMemcpyToSymbol
(
kNibbleLUT
,
lut
.
data
(),
lut
.
size
(),
/*offset=*/
0
,
cudaMemcpyHostToDevice
);
return
(
e
==
cudaSuccess
);
}
// Encodes `num_int4_elems` 4-bit values from `in` into `out` on the GPU.
//
// Uploads the lookup table, then launches unified_encode_int4b_device over
// num_int4_elems / 2 bytes (two 4-bit values per byte; an odd trailing
// element is not counted). Returns false if the table upload fails or if
// the launch reports an error; the kernel is not synchronized here.
static bool unified_encode_int4b(cutlass::int4b_t const* in,
                                 cutlass::int4b_t* out,
                                 size_t num_int4_elems) {
  static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1,
                "int4 storage must be 1 byte");

  // Build/upload LUT
  const bool lut_ready = upload_lut();
  if (!lut_ready) {
    return false;
  }

  const size_t nbytes = num_int4_elems >> 1;
  const uint8_t* src = reinterpret_cast<uint8_t const*>(in);
  uint8_t* dst = reinterpret_cast<uint8_t*>(out);

  // kernel launch params
  constexpr int block = 256;
  const size_t nvec = nbytes / sizeof(uint4);  // # of 16B vectors
  const int blocks_needed = int((nvec + block - 1) / block);
  // Never launch an empty grid, even when there is less than one vector.
  const int grid = (blocks_needed == 0) ? 1 : blocks_needed;

  unified_encode_int4b_device<<<grid, block>>>(src, dst, nbytes);

  return cudaGetLastError() == cudaSuccess;
}
torch
::
Tensor
encode_and_reorder_int4b
(
torch
::
Tensor
const
&
B
)
{
TORCH_CHECK
(
B
.
dtype
()
==
torch
::
kInt32
);
TORCH_CHECK
(
B
.
dim
()
==
2
);
...
...
@@ -477,8 +413,8 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
LayoutB_Reordered
layout_B_reordered
=
cute
::
tile_to_shape
(
LayoutAtomQuant
{},
shape_B
);
bool
ok
=
vllm
::
cutlass_w4a8
::
unified_encode_int4b
(
B_ptr
,
B_packed_ptr
,
n
*
k
);
bool
ok
=
vllm
::
cutlass_w4a8_utils
::
unified_encode_int4b
(
B_ptr
,
B_packed_ptr
,
n
*
k
);
TORCH_CHECK
(
ok
,
"unified_encode_int4b failed"
);
cutlass
::
reorder_tensor
(
B_packed_ptr
,
layout_B
,
layout_B_reordered
);
...
...
csrc/quantization/cutlass_w4a8/w4a8_utils.cu
0 → 100644
View file @
8d75f22e
#include "w4a8_utils.cuh"
#include <array>
#include <cuda_runtime.h>
#include <cstdio>
namespace vllm::cutlass_w4a8_utils {

/*
  GPU-accelerated implementation of cutlass::unified_encode_int4b.
  Constructs a lookup table in constant memory to map 8 bits
  (two 4-bit values) at a time. Assumes memory is contiguous
  and pointers are 16-byte aligned.
*/
__constant__ uint8_t kNibbleLUT[256];

// Remaps every byte of `in` through kNibbleLUT and stores it to `out`.
//
//   in      source bytes (two 4-bit values per byte)
//   out     destination bytes, same length as `in`
//   nbytes  number of bytes to process
//
// The bulk of the buffer is handled in 16-byte (uint4) chunks; any
// remaining nbytes % 16 bytes are handled one byte at a time so that
// buffers whose size is not a multiple of 16 are fully encoded.
__global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out,
                                            size_t nbytes) {
  constexpr size_t V = sizeof(uint4);  // 16 bytes
  // Widen before multiplying so the global index cannot wrap in 32 bits.
  const size_t tid = size_t(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t nthreads = size_t(gridDim.x) * blockDim.x;
  const size_t nvec = nbytes / V;

  // 1-D grid-stride loop over 16-byte chunks
  for (size_t vec = tid; vec < nvec; vec += nthreads) {
    uint4 v = reinterpret_cast<const uint4*>(in)[vec];
    uint8_t* b = reinterpret_cast<uint8_t*>(&v);
#pragma unroll
    for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]];
    reinterpret_cast<uint4*>(out)[vec] = v;
  }

  // Tail: bytes past the last full 16-byte chunk (at most 15 of them).
  for (size_t i = nvec * V + tid; i < nbytes; i += nthreads) {
    out[i] = kNibbleLUT[in[i]];
  }
}

// Builds the 256-entry byte remapping table on the host and copies it into
// the kNibbleLUT constant-memory symbol. Each byte is treated as two
// independent 4-bit values that are remapped with the same rule.
// Returns true when the copy reports cudaSuccess.
static bool upload_lut() {
  std::array<uint8_t, 256> lut{};
  auto map_nib = [](uint8_t v) -> uint8_t {
    // 1..7 -> (8 - v); keep 0 and 8..15
    return (v == 0 || (v & 0x8)) ? v : uint8_t(8 - v);
  };
  for (int b = 0; b < 256; ++b) {
    uint8_t lo = b & 0xF;
    uint8_t hi = (b >> 4) & 0xF;
    lut[b] = uint8_t((map_nib(hi) << 4) | map_nib(lo));
  }
  cudaError_t e = cudaMemcpyToSymbol(kNibbleLUT, lut.data(), lut.size(),
                                     /*offset=*/0, cudaMemcpyHostToDevice);
  return (e == cudaSuccess);
}

// Encodes `num_int4_elems` 4-bit values from `in` into `out` on the GPU.
//
// Uploads the lookup table, launches unified_encode_int4b_device over
// num_int4_elems / 2 bytes (two 4-bit values per byte; an odd trailing
// element is not counted) and waits for the device to finish.
//
// The kernel is launched without an explicit stream argument.
//
// Returns true on success. On a launch or runtime error a diagnostic is
// printed and false is returned; a failed table upload returns false
// without printing.
bool unified_encode_int4b(cutlass::int4b_t const* in, cutlass::int4b_t* out,
                          size_t num_int4_elems) {
  // Build/upload LUT
  if (!upload_lut()) return false;

  static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1,
                "int4 storage must be 1 byte");
  const size_t nbytes = num_int4_elems >> 1;

  auto* in_bytes = reinterpret_cast<uint8_t const*>(in);
  auto* out_bytes = reinterpret_cast<uint8_t*>(out);

  // kernel launch params
  constexpr int block = 256;
  const size_t nvec = nbytes / sizeof(uint4);  // # of 16B vectors
  int grid = int((nvec + block - 1) / block);
  // Always launch at least one block: when nbytes < 16 there are no full
  // vectors, but the kernel's byte-wise tail loop still has work to do.
  if (grid == 0) grid = 1;

  unified_encode_int4b_device<<<grid, block>>>(in_bytes, out_bytes, nbytes);

  // launch errors
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("unified_encode_int4b_device launch error: %s (%d)\n",
           cudaGetErrorString(err), err);
    return false;
  }

  // runtime errors
  err = cudaDeviceSynchronize();
  if (err != cudaSuccess) {
    printf("unified_encode_int4b_device runtime error: %s (%d)\n",
           cudaGetErrorString(err), err);
    return false;
  }

  return true;
}

}  // namespace vllm::cutlass_w4a8_utils
\ No newline at end of file
csrc/quantization/cutlass_w4a8/w4a8_utils.cuh
0 → 100644
View file @
8d75f22e
#pragma once
#include <cstddef>
#include "cutlass/numeric_types.h"
namespace
vllm
::
cutlass_w4a8_utils
{
// Encodes a buffer of 4-bit integers on the GPU (device-side counterpart of
// cutlass::unified_encode_int4b, per the comment in w4a8_utils.cu).
//
//   in              pointer to the source int4 values
//   out             pointer to the destination buffer
//   num_int4_elems  number of 4-bit elements to encode
//
// Returns a bool status; callers in this change treat `false` as failure.
// NOTE(review): `in`/`out` are presumably device pointers -- confirm against
// the definition and its callers.
bool
unified_encode_int4b
(
cutlass
::
int4b_t
const
*
in
,
cutlass
::
int4b_t
*
out
,
size_t
num_int4_elems
);
}
// namespace vllm::cutlass_w4a8_utils
\ No newline at end of file
csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
View file @
8d75f22e
This diff is collapsed.
Click to expand it.
csrc/quantization/fused_kernels/layernorm_utils.cuh
View file @
8d75f22e
This diff is collapsed.
Click to expand it.
csrc/quantization/w8a8/cutlass/moe/moe_data.cu
View file @
8d75f22e
...
...
@@ -136,15 +136,17 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
void
get_cutlass_moe_mm_problem_sizes_caller
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
)
{
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
,
std
::
optional
<
bool
>
force_swap_ab
=
std
::
nullopt
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
topk_ids
.
device
().
index
());
auto
options_int32
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
).
device
(
topk_ids
.
device
());
torch
::
Tensor
atomic_buffer
=
torch
::
zeros
(
num_experts
,
options_int32
);
// Swap-AB should be disabled for FP4 path
bool
may_swap_ab
=
(
!
blockscale_offsets
.
has_value
())
&&
(
topk_ids
.
numel
()
<=
SWAP_AB_THRESHOLD
);
bool
may_swap_ab
=
force_swap_ab
.
value_or
((
!
blockscale_offsets
.
has_value
())
&&
(
topk_ids
.
numel
()
<=
SWAP_AB_THRESHOLD
));
launch_compute_problem_sizes
(
topk_ids
,
problem_sizes1
,
problem_sizes2
,
atomic_buffer
,
num_experts
,
n
,
k
,
stream
,
...
...
csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
View file @
8d75f22e
...
...
@@ -80,7 +80,8 @@ void get_cutlass_moe_mm_data_caller(
void
get_cutlass_moe_mm_problem_sizes_caller
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
,
std
::
optional
<
bool
>
force_swap_ab
=
std
::
nullopt
);
void
get_cutlass_pplx_moe_mm_data_caller
(
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
...
...
@@ -303,14 +304,15 @@ void get_cutlass_moe_mm_data(
void
get_cutlass_moe_mm_problem_sizes
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
)
{
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
,
std
::
optional
<
bool
>
force_swap_ab
=
std
::
nullopt
)
{
int32_t
version_num
=
get_sm_version_num
();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_problem_sizes_caller
(
topk_ids
,
problem_sizes1
,
problem_sizes2
,
num_experts
,
n
,
k
,
blockscale_offsets
);
blockscale_offsets
,
force_swap_ab
);
return
;
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
...
...
csrc/quantization/w8a8/fp8/per_token_group_quant.cu
View file @
8d75f22e
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
5
6
7
8
…
36
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment