Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4c676e3d
Commit
4c676e3d
authored
Jun 20, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.9.1' into v0.9.1-dev
parents
b4c4464d
b6553be1
Changes
418
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1721 additions
and
179 deletions
+1721
-179
csrc/ops.h
csrc/ops.h
+65
-6
csrc/opt/activation_kernels_opt.cu
csrc/opt/activation_kernels_opt.cu
+3
-0
csrc/pos_encoding_kernels.cu
csrc/pos_encoding_kernels.cu
+80
-52
csrc/quantization/activation_kernels.cu
csrc/quantization/activation_kernels.cu
+121
-0
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+14
-2
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
...ization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
+23
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
...tlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
+279
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
+75
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
+5
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
...ization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
+51
-2
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu
+19
-10
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh
+2
-4
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
+69
-9
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
+5
-17
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
+5
-46
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+57
-14
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
+402
-0
csrc/quantization/fp4/nvfp4_experts_quant.cu
csrc/quantization/fp4/nvfp4_experts_quant.cu
+404
-0
csrc/quantization/fp4/nvfp4_quant_entry.cu
csrc/quantization/fp4/nvfp4_quant_entry.cu
+23
-1
csrc/quantization/fp8/common.cu
csrc/quantization/fp8/common.cu
+19
-16
No files found.
Too many changes to show.
To preserve performance only
418 of 418+
files are displayed.
Plain diff
Email patch
csrc/ops.h
View file @
4c676e3d
...
...
@@ -179,6 +179,31 @@ void merge_attn_states(torch::Tensor& output,
const
torch
::
Tensor
&
prefix_lse
,
const
torch
::
Tensor
&
suffix_output
,
const
torch
::
Tensor
&
suffix_lse
);
void
convert_vertical_slash_indexes
(
torch
::
Tensor
&
block_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
block_offset
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch
::
Tensor
&
column_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
column_index
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch
::
Tensor
q_seqlens
,
// [BATCH, ]
torch
::
Tensor
kv_seqlens
,
// [BATCH, ]
torch
::
Tensor
vertical_indexes
,
// [BATCH, N_HEADS, NNZ_V]
torch
::
Tensor
slash_indexes
,
// [BATCH, N_HEADS, NNZ_S]
int64_t
context_size
,
int64_t
block_size_M
,
int64_t
block_size_N
,
bool
causal
);
void
convert_vertical_slash_indexes_mergehead
(
torch
::
Tensor
&
block_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
block_offset
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch
::
Tensor
&
column_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
column_index
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch
::
Tensor
q_seqlens
,
// [BATCH, ]
torch
::
Tensor
kv_seqlens
,
// [BATCH, ]
torch
::
Tensor
vertical_indexes
,
// [BATCH, N_HEADS, NNZ_V]
torch
::
Tensor
slash_indexes
,
// [BATCH, N_HEADS, NNZ_S]
torch
::
Tensor
vertical_indices_count
,
// [N_HEADS, ]
torch
::
Tensor
slash_indices_count
,
int64_t
context_size
,
int64_t
block_size_M
,
int64_t
block_size_N
,
bool
causal
);
#endif
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
...
...
@@ -193,6 +218,11 @@ void rms_norm_opt(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weigh
void
fused_add_rms_norm_opt
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
double
epsilon
);
void
apply_repetition_penalties_
(
torch
::
Tensor
&
logits
,
const
torch
::
Tensor
&
prompt_mask
,
const
torch
::
Tensor
&
output_mask
,
const
torch
::
Tensor
&
repetition_penalties
);
// void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& weight, torch::Tensor& scale,
// double epsilon);
...
...
@@ -212,13 +242,13 @@ void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual,
// std::optional<torch::Tensor> residual);
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int64_t
head_size
,
std
::
optional
<
torch
::
Tensor
>
key
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
,
int64_t
rot_dim
,
std
::
optional
<
torch
::
Tensor
>
key
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
,
int64_t
rot_dim
,
torch
::
Tensor
&
cos_sin_cache_offsets
);
void
rotary_embedding_tgi
(
torch
::
Tensor
&
query
,
...
...
@@ -230,6 +260,9 @@ void rotary_embedding_tgi(
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
silu_and_mul_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
scale
);
void
mul_and_silu
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
...
...
@@ -316,6 +349,10 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
torch
::
Tensor
num_tokens_post_padded
,
int64_t
type
,
int64_t
row
,
int64_t
top_k
,
int64_t
tokens
);
torch
::
Tensor
ggml_moe_a8_vec
(
torch
::
Tensor
X
,
torch
::
Tensor
W
,
torch
::
Tensor
topk_ids
,
int64_t
top_k
,
int64_t
type
,
int64_t
row
,
int64_t
tokens
);
int64_t
ggml_moe_get_block_size
(
int64_t
type
);
#ifndef USE_ROCM
...
...
@@ -340,13 +377,29 @@ void cutlass_moe_mm(
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
);
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
,
bool
per_act_token
,
bool
per_out_ch
);
void
cutlass_fp4_group_mm
(
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
a_blockscale
,
const
torch
::
Tensor
&
b_blockscales
,
const
torch
::
Tensor
&
alphas
,
const
torch
::
Tensor
&
problem_sizes
,
const
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
sf_offsets
);
void
get_cutlass_moe_mm_data
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
);
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
void
get_cutlass_pplx_moe_mm_data
(
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
torch
::
Tensor
&
expert_num_tokens
,
const
int64_t
num_local_experts
,
const
int64_t
padded_m
,
const
int64_t
n
,
const
int64_t
k
);
void
cutlass_scaled_mm_azp
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
...
...
@@ -369,6 +422,12 @@ std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input_scale
);
void
scaled_fp4_experts_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_global_scale
,
torch
::
Tensor
const
&
input_offset_by_experts
,
torch
::
Tensor
const
&
output_scale_offset_by_experts
);
#endif
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
...
...
csrc/opt/activation_kernels_opt.cu
View file @
4c676e3d
...
...
@@ -161,6 +161,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
if (num_tokens == 0) { \
return; \
}
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
\
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
\
VLLM_DISPATCH_FLOATING_TYPES
(
\
...
...
csrc/pos_encoding_kernels.cu
View file @
4c676e3d
...
...
@@ -38,12 +38,14 @@ inline __device__ void apply_rotary_embedding(
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
scalar_t
*
__restrict__
key
,
// nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const
scalar_t
*
cache_ptr
,
const
int
head_size
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
rot_dim
,
const
int
token_idx
,
const
int64_t
query_stride
,
const
int64_t
key_stride
)
{
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int64_t
head_stride
)
{
const
int
embed_dim
=
rot_dim
/
2
;
const
scalar_t
*
cos_ptr
=
cache_ptr
;
const
scalar_t
*
sin_ptr
=
cache_ptr
+
embed_dim
;
...
...
@@ -51,19 +53,23 @@ inline __device__ void apply_rotary_embedding(
const
int
nq
=
num_heads
*
embed_dim
;
for
(
int
i
=
threadIdx
.
x
;
i
<
nq
;
i
+=
blockDim
.
x
)
{
const
int
head_idx
=
i
/
embed_dim
;
const
int64_t
token_head
=
token_idx
*
query_stride
+
head_idx
*
head_size
;
const
int64_t
token_head
=
token_idx
*
query_stride
+
head_idx
*
head_stride
;
const
int
rot_offset
=
i
%
embed_dim
;
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
}
const
int
nk
=
num_kv_heads
*
embed_dim
;
for
(
int
i
=
threadIdx
.
x
;
i
<
nk
;
i
+=
blockDim
.
x
)
{
const
int
head_idx
=
i
/
embed_dim
;
const
int64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_size
;
const
int
rot_offset
=
i
%
embed_dim
;
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
key
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
if
(
key
!=
nullptr
)
{
const
int
nk
=
num_kv_heads
*
embed_dim
;
for
(
int
i
=
threadIdx
.
x
;
i
<
nk
;
i
+=
blockDim
.
x
)
{
const
int
head_idx
=
i
/
embed_dim
;
const
int64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_stride
;
const
int
rot_offset
=
i
%
embed_dim
;
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
key
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
}
}
}
...
...
@@ -74,13 +80,15 @@ __global__ void rotary_embedding_kernel(
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
scalar_t
*
__restrict__
key
,
// nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
// 2]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
const
int64_t
head_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
// Each thread block is responsible for one token.
const
int
token_idx
=
blockIdx
.
x
;
int64_t
pos
=
positions
[
token_idx
];
...
...
@@ -88,7 +96,7 @@ __global__ void rotary_embedding_kernel(
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
token_idx
,
query_stride
,
key_stride
,
head_stride
);
}
template
<
typename
scalar_t
,
bool
IS_NEOX
>
...
...
@@ -98,15 +106,16 @@ __global__ void batched_rotary_embedding_kernel(
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
scalar_t
*
__restrict__
key
,
// nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
// 2]
const
int64_t
*
__restrict__
cos_sin_cache_offsets
,
// [batch_size, seq_len]
// or [num_tokens]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
const
int64_t
head_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
// Each thread block is responsible for one token.
const
int
token_idx
=
blockIdx
.
x
;
int64_t
pos
=
positions
[
token_idx
];
...
...
@@ -116,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel(
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
token_idx
,
query_stride
,
key_stride
,
head_stride
);
}
}
// namespace vllm
...
...
@@ -127,10 +136,12 @@ void rotary_embedding(
// [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
std
::
optional
<
torch
::
Tensor
>
key
,
// null or
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
)
{
...
...
@@ -138,40 +149,46 @@ void rotary_embedding(
int64_t
num_tokens
=
positions
.
numel
();
int
positions_ndim
=
positions
.
dim
();
// Make sure num_tokens dim is consistent across positions, query, and key
.
// Make sure num_tokens dim is consistent across positions, query, and key
TORCH_CHECK
(
positions_ndim
==
1
||
positions_ndim
==
2
,
"positions must have shape [num_tokens] or [batch_size, seq_len]"
);
if
(
positions_ndim
==
1
)
{
TORCH_CHECK
(
query
.
siz
e
(
0
)
==
positions
.
size
(
0
)
&&
key
.
size
(
0
)
==
positions
.
size
(
0
),
"query, key and positions must have the same number of tokens"
);
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
(
!
key
.
has_valu
e
()
||
key
->
size
(
0
)
==
positions
.
size
(
0
)
)
,
"query, key and positions must have the same number of tokens"
);
}
if
(
positions_ndim
==
2
)
{
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
key
.
size
(
0
)
==
positions
.
size
(
0
)
&&
(
!
key
.
has_value
()
||
key
->
size
(
0
)
==
positions
.
size
(
0
)
)
&&
query
.
size
(
1
)
==
positions
.
size
(
1
)
&&
key
.
size
(
1
)
==
positions
.
size
(
1
),
(
!
key
.
has_value
()
||
key
->
size
(
1
)
==
positions
.
size
(
1
)
)
,
"query, key and positions must have the same batch_size and seq_len"
);
}
// Make sure head_size is valid for query and key
// hidden_size = num_heads * head_size
int
query_hidden_size
=
query
.
numel
()
/
num_tokens
;
int
key_hidden_size
=
key
.
numel
()
/
num_tokens
;
int
key_hidden_size
=
key
.
has_value
()
?
key
->
numel
()
/
num_tokens
:
0
;
TORCH_CHECK
(
query_hidden_size
%
head_size
==
0
);
TORCH_CHECK
(
key_hidden_size
%
head_size
==
0
);
// Make sure query and key have consistent number of heads
int
num_heads
=
query_hidden_size
/
head_size
;
int
num_kv_heads
=
key_hidden_size
/
head_size
;
int
num_kv_heads
=
key
.
has_value
()
?
key_hidden_size
/
head_size
:
num_heads
;
TORCH_CHECK
(
num_heads
%
num_kv_heads
==
0
);
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
seq_dim_idx
=
positions_ndim
-
1
;
int64_t
query_stride
=
query
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
has_value
()
?
key
->
stride
(
seq_dim_idx
)
:
0
;
// Determine head stride: for [*, heads, head_size] use stride of last dim;
// for flat [*, heads*head_size], heads blocks are contiguous of size
// head_size
int
query_ndim
=
query
.
dim
();
int64_t
head_stride
=
(
query_ndim
==
positions_ndim
+
2
)
?
query
.
stride
(
-
2
)
:
head_size
;
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
...
...
@@ -181,15 +198,16 @@ void rotary_embedding(
if
(
is_neox
)
{
vllm
::
rotary_embedding_kernel
<
scalar_t
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
head_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
else
{
vllm
::
rotary_embedding_kernel
<
scalar_t
,
false
>
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
head_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
});
}
...
...
@@ -204,10 +222,12 @@ void batched_rotary_embedding(
// [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
std
::
optional
<
torch
::
Tensor
>
key
,
// null or
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
,
int64_t
rot_dim
,
...
...
@@ -221,38 +241,44 @@ void batched_rotary_embedding(
"cos_sin_cache_offsets"
);
int
positions_ndim
=
positions
.
dim
();
// Make sure num_tokens dim is consistent across positions, query, and key
.
// Make sure num_tokens dim is consistent across positions, query, and key
TORCH_CHECK
(
positions_ndim
==
1
||
positions_ndim
==
2
,
"positions must have shape [num_tokens] or [batch_size, seq_len]"
);
if
(
positions_ndim
==
1
)
{
TORCH_CHECK
(
query
.
siz
e
(
0
)
==
positions
.
size
(
0
)
&&
key
.
size
(
0
)
==
positions
.
size
(
0
),
"query, key and positions must have the same number of tokens"
);
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
(
!
key
.
has_valu
e
()
||
key
->
size
(
0
)
==
positions
.
size
(
0
)
)
,
"query, key and positions must have the same number of tokens"
);
}
if
(
positions_ndim
==
2
)
{
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
key
.
size
(
0
)
==
positions
.
size
(
0
)
&&
(
!
key
.
has_value
()
||
key
->
size
(
0
)
==
positions
.
size
(
0
)
)
&&
query
.
size
(
1
)
==
positions
.
size
(
1
)
&&
key
.
size
(
1
)
==
positions
.
size
(
1
),
(
!
key
.
has_value
()
||
key
->
size
(
1
)
==
positions
.
size
(
1
)
)
,
"query, key and positions must have the same batch_size and seq_len"
);
}
// Make sure head_size is valid for query and key
int
query_hidden_size
=
query
.
numel
()
/
num_tokens
;
int
key_hidden_size
=
key
.
numel
()
/
num_tokens
;
int
key_hidden_size
=
key
.
has_value
()
?
key
->
numel
()
/
num_tokens
:
0
;
TORCH_CHECK
(
query_hidden_size
%
head_size
==
0
);
TORCH_CHECK
(
key_hidden_size
%
head_size
==
0
);
// Make sure query and key have concistent number of heads
int
num_heads
=
query_hidden_size
/
head_size
;
int
num_kv_heads
=
key_hidden_size
/
head_size
;
int
num_kv_heads
=
key
.
has_value
()
?
key_hidden_size
/
head_size
:
num_heads
;
TORCH_CHECK
(
num_heads
%
num_kv_heads
==
0
);
int
seq_dim_idx
=
positions_ndim
-
1
;
int64_t
query_stride
=
query
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
has_value
()
?
key
->
stride
(
seq_dim_idx
)
:
0
;
// Determine head stride: for [*, heads, head_size] use stride of last dim;
// for flat [*, heads*head_size], heads blocks are contiguous of size
// head_size
int
query_ndim
=
query
.
dim
();
int64_t
head_stride
=
(
query_ndim
==
positions_ndim
+
2
)
?
query
.
stride
(
-
2
)
:
head_size
;
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
...
...
@@ -263,16 +289,18 @@ void batched_rotary_embedding(
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
true
>
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
key_stride
,
head_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
else
{
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
false
>
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
key_stride
,
head_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
});
}
csrc/quantization/activation_kernels.cu
0 → 100644
View file @
4c676e3d
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <cmath>
#include "core/math.hpp"
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"
namespace
vllm
{
template
<
typename
T
>
__device__
__forceinline__
T
silu_kernel
(
const
T
&
x
)
{
// x * sigmoid(x)
return
(
T
)(((
float
)
x
)
/
(
1.0
f
+
expf
((
float
)
-
x
)));
}
// Activation and gating kernel template.
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
),
typename
fp8_type
>
__global__
void
act_and_mul_quant_kernel
(
fp8_type
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., 2, d]
const
float
*
scale
,
const
int
d
)
{
const
int32_t
blocks_per_token
=
gridDim
.
y
;
const
int32_t
elems_per_128bit_load
=
(
128
/
8
)
/
sizeof
(
scalar_t
);
// We don't expect the hidden dimension to exceed 32 bits so int32 should
// be safe here.
const
int32_t
tgt_elems_per_block
=
div_ceil
(
d
,
blocks_per_token
);
const
int32_t
elems_per_block
=
round_to_next_multiple_of
(
tgt_elems_per_block
,
elems_per_128bit_load
);
const
int32_t
block_start
=
blockIdx
.
y
*
elems_per_block
;
int32_t
block_end
=
block_start
+
elems_per_block
;
block_end
=
block_end
>
d
?
d
:
block_end
;
// token_idx is 64 bit to prevent 32 bit overflow when the number of tokens
// is very large
const
int64_t
token_idx
=
blockIdx
.
x
;
const
scalar_t
*
__restrict__
x_ptr
=
input
+
token_idx
*
2
*
d
;
const
scalar_t
*
__restrict__
y_ptr
=
input
+
token_idx
*
2
*
d
+
d
;
fp8_type
*
__restrict__
out_ptr
=
out
+
token_idx
*
d
;
// 128-bit vectorized code
const
int32_t
vec_loop_end
=
round_to_previous_multiple_of
(
elems_per_128bit_load
,
block_end
);
const
int32_t
vec_end_idx
=
vec_loop_end
/
elems_per_128bit_load
;
const
int32_t
vec_start_idx
=
block_start
/
elems_per_128bit_load
;
const
int4
*
__restrict__
x_128bit_ptr
=
reinterpret_cast
<
const
int4
*>
(
x_ptr
);
const
int4
*
__restrict__
y_128bit_ptr
=
reinterpret_cast
<
const
int4
*>
(
y_ptr
);
int2
*
__restrict__
out_128bit_ptr
=
reinterpret_cast
<
int2
*>
(
out_ptr
);
float
inverted_scale
=
1
/
*
scale
;
#pragma unroll
for
(
int32_t
vec_idx
=
vec_start_idx
+
threadIdx
.
x
;
vec_idx
<
vec_end_idx
;
vec_idx
+=
blockDim
.
x
)
{
const
int4
x_128bit
=
VLLM_LDG
(
&
x_128bit_ptr
[
vec_idx
]);
const
int4
y_128bit
=
VLLM_LDG
(
&
y_128bit_ptr
[
vec_idx
]);
using
scalar_128bit_vec_t
=
std
::
array
<
scalar_t
,
elems_per_128bit_load
>
;
using
scalar_64bit_vec_t
=
std
::
array
<
fp8_type
,
elems_per_128bit_load
>
;
scalar_64bit_vec_t
out_vec
;
const
auto
x_vec
=
reinterpret_cast
<
scalar_128bit_vec_t
const
&>
(
x_128bit
);
const
auto
y_vec
=
reinterpret_cast
<
scalar_128bit_vec_t
const
&>
(
y_128bit
);
#pragma unroll
for
(
int
i
=
0
;
i
<
elems_per_128bit_load
;
i
++
)
{
out_vec
[
i
]
=
scaled_fp8_conversion
<
true
,
fp8_type
>
(
ACT_FN
(
x_vec
[
i
])
*
y_vec
[
i
],
inverted_scale
);
}
out_128bit_ptr
[
vec_idx
]
=
reinterpret_cast
<
const
int2
&>
(
out_vec
);
}
// Scalar cleanup code
if
(
block_end
>
vec_loop_end
)
{
for
(
int64_t
idx
=
vec_loop_end
+
threadIdx
.
x
;
idx
<
block_end
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
x
=
VLLM_LDG
(
&
x_ptr
[
idx
]);
const
scalar_t
y
=
VLLM_LDG
(
&
y_ptr
[
idx
]);
out_ptr
[
idx
]
=
scaled_fp8_conversion
<
true
,
fp8_type
>
(
ACT_FN
(
x
)
*
y
,
inverted_scale
);
}
}
}
}
// namespace vllm
// Launch activation, gating, and quantize kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4); \
dim3 block(std::min(d, 512)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \
VLLM_DISPATCH_FP8_TYPES( \
out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \
vllm::act_and_mul_quant_kernel<scalar_t, KERNEL<scalar_t>, \
fp8_t> \
<<<grid, block, 0, stream>>>(out.data_ptr<fp8_t>(), \
input.data_ptr<scalar_t>(), \
scale.data_ptr<float>(), d); \
}); \
});
void
silu_and_mul_quant
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
,
// [..., 2 * d]
torch
::
Tensor
&
scale
)
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat8_e4m3fn
||
out
.
dtype
()
==
torch
::
kFloat8_e4m3fnuz
);
TORCH_CHECK
(
input
.
dtype
()
==
torch
::
kFloat16
||
input
.
dtype
()
==
torch
::
kBFloat16
);
TORCH_CHECK
(
input
.
size
(
-
1
)
%
2
==
0
);
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
silu_kernel
);
}
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
View file @
4c676e3d
...
...
@@ -26,7 +26,13 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
float
dst
=
std
::
nearbyint
(
x
);
// saturate
dst
=
std
::
clamp
(
dst
,
i8_min
,
i8_max
);
// See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
// Arch/gcc14. The following replaces std::clamp usage with similar logic
// dst = std::clamp(dst, i8_min, i8_max);
dst
=
(
dst
<
i8_min
)
?
i8_min
:
(
dst
>
i8_max
)
?
i8_max
:
dst
;
return
static_cast
<
int8_t
>
(
dst
);
#else
// CUDA path
...
...
@@ -79,7 +85,13 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
static_cast
<
int32_t
>
(
std
::
numeric_limits
<
int8_t
>::
max
());
// saturate
int32_t
dst
=
std
::
clamp
(
x
,
i8_min
,
i8_max
);
// See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
// Arch/gcc14. The following replaces std::clamp usage with similar logic
// int32_t dst = std::clamp(x, i8_min, i8_max);
int32_t
dst
=
(
x
<
i8_min
)
?
i8_min
:
(
x
>
i8_max
)
?
i8_max
:
x
;
return
static_cast
<
int8_t
>
(
dst
);
#else
// CUDA path
...
...
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
0 → 100644
View file @
4c676e3d
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
void
cutlass_scaled_mm_blockwise_sm100_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
cutlass_gemm_blockwise_sm100_fp8_dispatch
<
cutlass
::
bfloat16_t
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
cutlass_gemm_blockwise_sm100_fp8_dispatch
<
cutlass
::
half_t
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
0 → 100644
View file @
4c676e3d
#pragma once
#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh"
namespace
vllm
{
using
namespace
cute
;
// clang-format off
template
<
class
OutType
,
int
ScaleGranularityM
,
int
ScaleGranularityN
,
int
ScaleGranularityK
,
class
MmaTileShape
,
class
ClusterShape
,
class
EpilogueScheduler
,
class
MainloopScheduler
,
bool
swap_ab_
=
false
>
struct
cutlass_3x_gemm_fp8_blockwise
{
static
constexpr
bool
swap_ab
=
swap_ab_
;
using
ElementAB
=
cutlass
::
float_e4m3_t
;
using
ElementA
=
ElementAB
;
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
using
LayoutA_Transpose
=
typename
cutlass
::
layout
::
LayoutTranspose
<
LayoutA
>::
type
;
static
constexpr
int
AlignmentA
=
128
/
cutlass
::
sizeof_bits
<
ElementA
>::
value
;
using
ElementB
=
ElementAB
;
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
using
LayoutB_Transpose
=
typename
cutlass
::
layout
::
LayoutTranspose
<
LayoutB
>::
type
;
static
constexpr
int
AlignmentB
=
128
/
cutlass
::
sizeof_bits
<
ElementB
>::
value
;
using
ElementD
=
OutType
;
using
LayoutD
=
cutlass
::
layout
::
RowMajor
;
using
LayoutD_Transpose
=
typename
cutlass
::
layout
::
LayoutTranspose
<
LayoutD
>::
type
;
static
constexpr
int
AlignmentD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
using
ElementC
=
void
;
// TODO: support bias
using
LayoutC
=
LayoutD
;
using
LayoutC_Transpose
=
LayoutD_Transpose
;
static
constexpr
int
AlignmentC
=
AlignmentD
;
using
ElementAccumulator
=
float
;
using
ElementCompute
=
float
;
using
ElementBlockScale
=
float
;
using
ScaleConfig
=
conditional_t
<
swap_ab
,
cutlass
::
detail
::
Sm100BlockwiseScaleConfig
<
ScaleGranularityM
,
ScaleGranularityN
,
ScaleGranularityK
,
cute
::
UMMA
::
Major
::
K
,
cute
::
UMMA
::
Major
::
MN
>
,
cutlass
::
detail
::
Sm100BlockwiseScaleConfig
<
ScaleGranularityM
,
ScaleGranularityN
,
ScaleGranularityK
,
cute
::
UMMA
::
Major
::
MN
,
cute
::
UMMA
::
Major
::
K
>>
;
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
using
LayoutSFA
=
decltype
(
ScaleConfig
::
deduce_layoutSFA
());
using
LayoutSFB
=
decltype
(
ScaleConfig
::
deduce_layoutSFB
());
using
ArchTag
=
cutlass
::
arch
::
Sm100
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
static
constexpr
auto
RoundStyle
=
cutlass
::
FloatRoundStyle
::
round_to_nearest
;
using
ElementScalar
=
float
;
using
DefaultOperation
=
cutlass
::
epilogue
::
fusion
::
LinearCombination
<
ElementD
,
ElementCompute
,
ElementC
,
ElementScalar
,
RoundStyle
>
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
MmaTileShape
,
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementAccumulator
,
ElementCompute
,
ElementC
,
conditional_t
<
swap_ab
,
LayoutC_Transpose
,
LayoutC
>
,
AlignmentC
,
ElementD
,
conditional_t
<
swap_ab
,
LayoutD_Transpose
,
LayoutD
>
,
AlignmentD
,
EpilogueScheduler
,
DefaultOperation
>::
CollectiveOp
;
using
StageCountType
=
cutlass
::
gemm
::
collective
::
StageCountAuto
;
using
CollectiveMainloop
=
conditional_t
<
swap_ab
,
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
ElementB
,
cute
::
tuple
<
LayoutB_Transpose
,
LayoutSFA
>
,
AlignmentB
,
ElementA
,
cute
::
tuple
<
LayoutA_Transpose
,
LayoutSFB
>
,
AlignmentA
,
ElementAccumulator
,
MmaTileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
MainloopScheduler
>::
CollectiveOp
,
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
ElementA
,
cute
::
tuple
<
LayoutA
,
LayoutSFA
>
,
AlignmentA
,
ElementB
,
cute
::
tuple
<
LayoutB
,
LayoutSFB
>
,
AlignmentB
,
ElementAccumulator
,
MmaTileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
MainloopScheduler
>::
CollectiveOp
>
;
using
KernelType
=
enable_sm100_only
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
>>
;
struct
GemmKernel
:
public
KernelType
{};
};
template
<
typename
Gemm
>
void
cutlass_gemm_caller_blockwise
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
static
constexpr
bool
swap_ab
=
Gemm
::
swap_ab
;
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
using
StrideA
=
typename
Gemm
::
GemmKernel
::
StrideA
;
using
StrideB
=
typename
Gemm
::
GemmKernel
::
StrideB
;
using
StrideD
=
typename
Gemm
::
GemmKernel
::
StrideD
;
using
StrideC
=
typename
Gemm
::
GemmKernel
::
StrideC
;
using
LayoutSFA
=
typename
Gemm
::
LayoutSFA
;
using
LayoutSFB
=
typename
Gemm
::
LayoutSFB
;
using
ScaleConfig
=
typename
Gemm
::
ScaleConfig
;
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
int32_t
m
=
a
.
size
(
0
),
n
=
b
.
size
(
1
),
k
=
a
.
size
(
1
);
StrideA
a_stride
;
StrideB
b_stride
;
StrideC
c_stride
;
a_stride
=
cutlass
::
make_cute_packed_stride
(
StrideA
{},
cute
::
make_shape
(
m
,
k
,
1
));
b_stride
=
cutlass
::
make_cute_packed_stride
(
StrideB
{},
cute
::
make_shape
(
n
,
k
,
1
));
c_stride
=
cutlass
::
make_cute_packed_stride
(
StrideC
{},
swap_ab
?
cute
::
make_shape
(
n
,
m
,
1
)
:
cute
::
make_shape
(
m
,
n
,
1
));
LayoutSFA
layout_SFA
=
swap_ab
?
ScaleConfig
::
tile_atom_to_shape_SFA
(
make_shape
(
n
,
m
,
k
,
1
))
:
ScaleConfig
::
tile_atom_to_shape_SFA
(
make_shape
(
m
,
n
,
k
,
1
));
LayoutSFB
layout_SFB
=
swap_ab
?
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
n
,
m
,
k
,
1
))
:
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
m
,
n
,
k
,
1
));
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
auto
a_scales_ptr
=
static_cast
<
float
*>
(
a_scales
.
data_ptr
());
auto
b_scales_ptr
=
static_cast
<
float
*>
(
b_scales
.
data_ptr
());
auto
mainloop_args
=
[
&
](){
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
if
(
swap_ab
)
{
return
typename
GemmKernel
::
MainloopArguments
{
b_ptr
,
b_stride
,
a_ptr
,
a_stride
,
b_scales_ptr
,
layout_SFA
,
a_scales_ptr
,
layout_SFB
};
}
else
{
return
typename
GemmKernel
::
MainloopArguments
{
a_ptr
,
a_stride
,
b_ptr
,
b_stride
,
a_scales_ptr
,
layout_SFA
,
b_scales_ptr
,
layout_SFB
};
}
}();
auto
prob_shape
=
swap_ab
?
cute
::
make_shape
(
n
,
m
,
k
,
1
)
:
cute
::
make_shape
(
m
,
n
,
k
,
1
);
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
{},
c_ptr
,
c_stride
,
c_ptr
,
c_stride
};
c3x
::
cutlass_gemm_caller
<
GemmKernel
>
(
a
.
device
(),
prob_shape
,
mainloop_args
,
epilogue_args
);
}
template
<
typename
OutType
>
void
cutlass_gemm_blockwise_sm100_fp8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
int32_t
m
=
a
.
size
(
0
),
n
=
b
.
size
(
1
),
k
=
a
.
size
(
1
),
sms
;
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
a
.
get_device
());
constexpr
int
TILE_K
=
128
;
// TODO: better heuristics
bool
swap_ab
=
(
m
<
16
)
||
(
m
%
4
!=
0
);
bool
use_tma_epilogue
=
(
m
*
n
)
%
4
==
0
;
if
(
!
swap_ab
)
{
constexpr
int
TILE_N
=
128
;
int
tile_m
=
256
;
if
(
cuda_utils
::
ceil_div
(
n
,
TILE_N
)
*
cuda_utils
::
ceil_div
(
m
,
64
)
<=
sms
)
{
tile_m
=
64
;
}
else
if
(
cuda_utils
::
ceil_div
(
n
,
TILE_N
)
*
cuda_utils
::
ceil_div
(
m
,
128
)
<=
sms
)
{
tile_m
=
128
;
}
if
(
tile_m
==
64
)
{
if
(
use_tma_epilogue
)
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
1
,
TILE_N
,
TILE_K
,
Shape
<
_64
,
Int
<
TILE_N
>
,
Int
<
TILE_K
>>
,
Shape
<
_1
,
_1
,
_1
>
,
cutlass
::
epilogue
::
TmaWarpSpecialized1Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwise1SmSm100
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
1
,
TILE_N
,
TILE_K
,
Shape
<
_64
,
Int
<
TILE_N
>
,
Int
<
TILE_K
>>
,
Shape
<
_1
,
_1
,
_1
>
,
cutlass
::
epilogue
::
NoSmemWarpSpecialized1Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwise1SmSm100
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
else
if
(
tile_m
==
128
)
{
if
(
use_tma_epilogue
)
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
1
,
TILE_N
,
TILE_K
,
Shape
<
_128
,
Int
<
TILE_N
>
,
Int
<
TILE_K
>>
,
Shape
<
_1
,
_1
,
_1
>
,
cutlass
::
epilogue
::
TmaWarpSpecialized1Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwise1SmSm100
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
1
,
TILE_N
,
TILE_K
,
Shape
<
_128
,
Int
<
TILE_N
>
,
Int
<
TILE_K
>>
,
Shape
<
_1
,
_1
,
_1
>
,
cutlass
::
epilogue
::
NoSmemWarpSpecialized1Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwise1SmSm100
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
else
{
// tile_m == 256
if
(
use_tma_epilogue
)
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
1
,
TILE_N
,
TILE_K
,
Shape
<
_256
,
Int
<
TILE_N
>
,
Int
<
TILE_K
>>
,
Shape
<
_2
,
_1
,
_1
>
,
cutlass
::
epilogue
::
TmaWarpSpecialized2Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwise2SmSm100
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
1
,
TILE_N
,
TILE_K
,
Shape
<
_256
,
Int
<
TILE_N
>
,
Int
<
TILE_K
>>
,
Shape
<
_2
,
_1
,
_1
>
,
cutlass
::
epilogue
::
NoSmemWarpSpecialized2Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwise2SmSm100
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
else
{
// TODO: Test more tile N configs
constexpr
int
TILE_M
=
128
;
constexpr
int
TILE_N
=
16
;
// TMA epilogue isn't compatible with Swap A/B
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
TILE_M
,
1
,
TILE_K
,
Shape
<
Int
<
TILE_M
>
,
Int
<
TILE_N
>
,
Int
<
TILE_K
>>
,
Shape
<
_1
,
_1
,
_1
>
,
cutlass
::
epilogue
::
NoSmemWarpSpecialized1Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwise1SmSm100
,
true
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
0 → 100644
View file @
4c676e3d
#include <torch/all.h>
#include "cuda_utils.h"
#include "cutlass_extensions/common.hpp"
template
<
typename
Fp8Func
,
typename
Int8Func
,
typename
BlockwiseFunc
>
void
dispatch_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
,
Fp8Func
fp8_func
,
Int8Func
int8_func
,
BlockwiseFunc
blockwise_func
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
int
M
=
a
.
size
(
0
),
N
=
b
.
size
(
1
),
K
=
a
.
size
(
1
);
if
((
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
))
&&
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
)))
{
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
)
{
fp8_func
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
if
constexpr
(
!
std
::
is_same_v
<
Int8Func
,
std
::
nullptr_t
>
)
{
int8_func
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
TORCH_CHECK
(
false
,
"Int8 not supported for this architecture"
);
}
}
}
else
{
TORCH_CHECK
(
a_scales
.
dim
()
==
2
,
"a scale must be 2d tensor."
);
TORCH_CHECK
(
b_scales
.
dim
()
==
2
,
"b scale must be 2d tensor."
);
int32_t
version_num
=
get_sm_version_num
();
if
(
version_num
>=
100
)
{
TORCH_CHECK
(
a
.
size
(
0
)
==
a_scales
.
size
(
0
)
&&
cuda_utils
::
ceil_div
(
a
.
size
(
1
),
int64_t
(
128
))
==
a_scales
.
size
(
1
),
"a_scale_group_shape must be [1, 128]."
);
TORCH_CHECK
(
cuda_utils
::
ceil_div
(
b
.
size
(
0
),
int64_t
(
128
))
==
b_scales
.
size
(
0
)
&&
cuda_utils
::
ceil_div
(
b
.
size
(
1
),
int64_t
(
128
))
==
b_scales
.
size
(
1
),
"b_scale_group_shape must be [128, 128]."
);
}
else
{
// TODO: Remove this after using cutlass sm90 blockwise scaling gemm
// kernel, or introducing ceil_div to the load_init() of mainloop.
using
GroupShape
=
std
::
array
<
int64_t
,
2
>
;
auto
make_group_shape
=
[](
torch
::
Tensor
const
&
x
,
torch
::
Tensor
const
&
s
)
->
GroupShape
{
TORCH_CHECK
(
s
.
dim
()
==
2
,
"cutlass_scaled_mm group scales must be 2D"
);
return
{
cuda_utils
::
ceil_div
(
x
.
size
(
0
),
s
.
size
(
0
)),
cuda_utils
::
ceil_div
(
x
.
size
(
1
),
s
.
size
(
1
))};
};
GroupShape
a_scale_group_shape
=
make_group_shape
(
a
,
a_scales
);
GroupShape
b_scale_group_shape
=
make_group_shape
(
b
,
b_scales
);
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK
((
a_scale_group_shape
==
GroupShape
{
1
,
128
}
&&
b_scale_group_shape
==
GroupShape
{
128
,
128
}
&&
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
&&
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
),
"cutlass_scaled_mm only supports datatype float8_e4m3fn.
\n
"
"a_scale_group_shape must be [1, 128]. Got: ["
,
a_scale_group_shape
[
0
],
", "
,
a_scale_group_shape
[
1
],
"]
\n
"
"b_scale_group_shape must be [128, 128]. Got: ["
,
b_scale_group_shape
[
0
],
", "
,
b_scale_group_shape
[
1
],
"]"
);
}
TORCH_CHECK
(
!
bias
,
"Bias not yet supported blockwise scaled_mm"
);
blockwise_func
(
c
,
a
,
b
,
a_scales
,
b_scales
);
}
}
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
View file @
4c676e3d
...
...
@@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_blockwise_sm100_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
View file @
4c676e3d
...
...
@@ -15,6 +15,7 @@ using c3x::cutlass_gemm_caller;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm100_fp8_config_default
{
// M in (128, inf)
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
...
...
@@ -25,6 +26,34 @@ struct sm100_fp8_config_default {
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm100_fp8_config_M128
{
// M in (64, 128]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
TileShape
=
Shape
<
_128
,
_128
,
_64
>
;
using
ClusterShape
=
Shape
<
_2
,
_2
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm_sm100
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm100_fp8_config_M64
{
// M in [1, 64]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_8
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm_sm100
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
...
...
@@ -39,8 +68,28 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
using
Cutlass3xGemmDefault
=
typename
sm100_fp8_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
return
cutlass_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
using
Cutlass3xGemmM64
=
typename
sm100_fp8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
typename
sm100_fp8_config_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
64
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
64
)
{
// m in [1, 64]
return
cutlass_gemm_caller
<
Cutlass3xGemmM64
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// m in (64, 128]
return
cutlass_gemm_caller
<
Cutlass3xGemmM128
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// m in (128, inf)
return
cutlass_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
...
...
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu
View file @
4c676e3d
...
...
@@ -84,7 +84,8 @@ void run_cutlass_moe_mm_sm90(
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
)
{
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
,
bool
per_act_token
,
bool
per_out_ch
)
{
TORCH_CHECK
(
a_tensors
.
size
(
0
)
>
0
,
"No input A tensors provided."
);
TORCH_CHECK
(
b_tensors
.
size
(
0
)
>
0
,
"No input B tensors provided."
);
TORCH_CHECK
(
out_tensors
.
size
(
0
)
>
0
,
"No output tensors provided."
);
...
...
@@ -113,19 +114,23 @@ void run_cutlass_moe_mm_sm90(
if
(
n
>=
8192
)
{
cutlass_group_gemm_caller
<
Cutlass3xGemmN8192
>
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
per_act_token
,
per_out_ch
);
}
else
if
(
k
>=
8192
)
{
cutlass_group_gemm_caller
<
Cutlass3xGemmK8192
>
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
per_act_token
,
per_out_ch
);
}
else
if
(
m
<=
16
)
{
cutlass_group_gemm_caller
<
Cutlass3xGemmM16
>
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
per_act_token
,
per_out_ch
);
}
else
{
cutlass_group_gemm_caller
<
Cutlass3xGemmDefault
>
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
per_act_token
,
per_out_ch
);
}
}
...
...
@@ -134,15 +139,18 @@ void dispatch_moe_mm_sm90(
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
)
{
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
,
bool
per_act_token
,
bool
per_out_ch
)
{
if
(
out_tensors
.
dtype
()
==
torch
::
kBFloat16
)
{
run_cutlass_moe_mm_sm90
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
>
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
per_act_token
,
per_out_ch
);
}
else
{
run_cutlass_moe_mm_sm90
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
>
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
per_act_token
,
per_out_ch
);
}
}
...
...
@@ -153,8 +161,9 @@ void cutlass_moe_mm_sm90(
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
)
{
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
,
bool
per_act_token
,
bool
per_out_ch
)
{
dispatch_moe_mm_sm90
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
c_strides
,
per_act_token
,
per_out_ch
);
}
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh
View file @
4c676e3d
...
...
@@ -76,7 +76,8 @@ void cutlass_group_gemm_caller(
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
)
{
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
,
bool
per_act_token
,
bool
per_out_ch
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
...
...
@@ -84,9 +85,6 @@ void cutlass_group_gemm_caller(
int
k_size
=
a_tensors
.
size
(
1
);
int
n_size
=
out_tensors
.
size
(
1
);
bool
per_act_token
=
a_scales
.
numel
()
!=
1
;
bool
per_out_ch
=
b_scales
.
numel
()
!=
num_experts
;
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a_tensors
.
device
().
index
());
auto
options_int
=
...
...
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
View file @
4c676e3d
...
...
@@ -7,7 +7,7 @@
constexpr
uint64_t
THREADS_PER_EXPERT
=
512
;
__global__
void
compute_problem_sizes
(
const
int
*
__restrict__
topk_ids
,
__global__
void
compute_problem_sizes
(
const
u
int
32_t
*
__restrict__
topk_ids
,
int32_t
*
problem_sizes1
,
int32_t
*
problem_sizes2
,
int32_t
*
atomic_buffer
,
...
...
@@ -45,7 +45,24 @@ __global__ void compute_expert_offsets(
}
}
__global__
void
compute_arg_sorts
(
const
int
*
__restrict__
topk_ids
,
__global__
void
compute_expert_blockscale_offsets
(
const
int32_t
*
__restrict__
problem_sizes1
,
int32_t
*
expert_offsets
,
int32_t
*
blockscale_offsets
,
int32_t
*
atomic_buffer
,
const
int
num_experts
)
{
int32_t
tot_offset
=
0
;
int32_t
tot_offset_round
=
0
;
expert_offsets
[
0
]
=
0
;
blockscale_offsets
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
atomic_buffer
[
i
]
=
tot_offset
;
tot_offset
+=
problem_sizes1
[
i
*
3
];
expert_offsets
[
i
+
1
]
=
tot_offset
;
tot_offset_round
+=
(
problem_sizes1
[
i
*
3
]
+
(
128
-
1
))
/
128
*
128
;
blockscale_offsets
[
i
+
1
]
=
tot_offset_round
;
}
}
__global__
void
compute_arg_sorts
(
const
uint32_t
*
__restrict__
topk_ids
,
const
int32_t
*
__restrict__
expert_offsets
,
int32_t
*
input_permutation
,
int32_t
*
output_permutation
,
...
...
@@ -77,7 +94,8 @@ void get_cutlass_moe_mm_data_caller(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
)
{
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
topk_ids
.
device
().
index
());
auto
options_int32
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
).
device
(
topk_ids
.
device
());
...
...
@@ -85,19 +103,61 @@ void get_cutlass_moe_mm_data_caller(
int
num_threads
=
min
(
THREADS_PER_EXPERT
,
topk_ids
.
numel
());
compute_problem_sizes
<<<
num_experts
,
num_threads
,
0
,
stream
>>>
(
static_cast
<
const
int32_t
*>
(
topk_ids
.
data_ptr
()),
static_cast
<
const
u
int32_t
*>
(
topk_ids
.
data_ptr
()),
static_cast
<
int32_t
*>
(
problem_sizes1
.
data_ptr
()),
static_cast
<
int32_t
*>
(
problem_sizes2
.
data_ptr
()),
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
topk_ids
.
numel
(),
n
,
k
);
compute_expert_offsets
<<<
1
,
1
,
0
,
stream
>>>
(
static_cast
<
const
int32_t
*>
(
problem_sizes1
.
data_ptr
()),
static_cast
<
int32_t
*>
(
expert_offsets
.
data_ptr
()),
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
num_experts
);
if
(
blockscale_offsets
.
has_value
())
{
compute_expert_blockscale_offsets
<<<
1
,
1
,
0
,
stream
>>>
(
static_cast
<
const
int32_t
*>
(
problem_sizes1
.
data_ptr
()),
static_cast
<
int32_t
*>
(
expert_offsets
.
data_ptr
()),
static_cast
<
int32_t
*>
(
blockscale_offsets
.
value
().
data_ptr
()),
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
num_experts
);
}
else
{
compute_expert_offsets
<<<
1
,
1
,
0
,
stream
>>>
(
static_cast
<
const
int32_t
*>
(
problem_sizes1
.
data_ptr
()),
static_cast
<
int32_t
*>
(
expert_offsets
.
data_ptr
()),
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
num_experts
);
}
compute_arg_sorts
<<<
num_experts
,
num_threads
,
0
,
stream
>>>
(
static_cast
<
const
int32_t
*>
(
topk_ids
.
data_ptr
()),
static_cast
<
const
u
int32_t
*>
(
topk_ids
.
data_ptr
()),
static_cast
<
const
int32_t
*>
(
expert_offsets
.
data_ptr
()),
static_cast
<
int32_t
*>
(
input_permutation
.
data_ptr
()),
static_cast
<
int32_t
*>
(
output_permutation
.
data_ptr
()),
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
topk_ids
.
numel
(),
topk_ids
.
size
(
1
));
}
__global__
void
compute_pplx_data
(
int32_t
*
expert_offsets
,
int32_t
*
problem_sizes1
,
int32_t
*
problem_sizes2
,
const
int32_t
*
__restrict__
expert_num_tokens
,
const
int
padded_m
,
const
int
n
,
const
int
k
)
{
int
expert_idx
=
threadIdx
.
x
;
expert_offsets
[
expert_idx
]
=
expert_idx
*
padded_m
;
problem_sizes1
[
expert_idx
*
3
]
=
expert_num_tokens
[
expert_idx
];
problem_sizes1
[
expert_idx
*
3
+
1
]
=
2
*
n
;
problem_sizes1
[
expert_idx
*
3
+
2
]
=
k
;
problem_sizes2
[
expert_idx
*
3
]
=
expert_num_tokens
[
expert_idx
];
problem_sizes2
[
expert_idx
*
3
+
1
]
=
k
;
problem_sizes2
[
expert_idx
*
3
+
2
]
=
n
;
}
void
get_cutlass_pplx_moe_mm_data_caller
(
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
torch
::
Tensor
&
expert_num_tokens
,
const
int64_t
num_local_experts
,
const
int64_t
padded_m
,
const
int64_t
n
,
const
int64_t
k
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
expert_offsets
.
device
().
index
());
compute_pplx_data
<<<
1
,
num_local_experts
,
0
,
stream
>>>
(
static_cast
<
int32_t
*>
(
expert_offsets
.
data_ptr
()),
static_cast
<
int32_t
*>
(
problem_sizes1
.
data_ptr
()),
static_cast
<
int32_t
*>
(
problem_sizes2
.
data_ptr
()),
static_cast
<
const
int32_t
*>
(
expert_num_tokens
.
data_ptr
()),
padded_m
,
n
,
k
);
}
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
View file @
4c676e3d
#include
<cudaTypedefs.h>
#include
"c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"
#include "cuda_utils.h"
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm100 (Blackwell).
...
...
@@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
int
M
=
a
.
size
(
0
),
N
=
b
.
size
(
1
),
K
=
a
.
size
(
1
);
TORCH_CHECK
(
(
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
))
&&
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
)),
"Currently, block scaled fp8 gemm is not implemented for Blackwell"
);
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
,
"Currently, only fp8 gemm is implemented for Blackwell"
);
vllm
::
cutlass_scaled_mm_sm100_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
dispatch_scaled_mm
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
,
vllm
::
cutlass_scaled_mm_sm100_fp8
,
nullptr
,
// int8 not supported on SM100
vllm
::
cutlass_scaled_mm_blockwise_sm100_fp8
);
}
#endif
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
View file @
4c676e3d
#include
<cudaTypedefs.h>
#include
"c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"
#include "cuda_utils.h"
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper).
...
...
@@ -15,49 +13,10 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
int
M
=
a
.
size
(
0
),
N
=
b
.
size
(
1
),
K
=
a
.
size
(
1
);
if
((
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
))
&&
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
)))
{
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
)
{
vllm
::
cutlass_scaled_mm_sm90_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
vllm
::
cutlass_scaled_mm_sm90_int8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
}
else
{
using
GroupShape
=
std
::
array
<
int64_t
,
2
>
;
auto
make_group_shape
=
[](
torch
::
Tensor
const
&
x
,
torch
::
Tensor
const
&
s
)
->
GroupShape
{
TORCH_CHECK
(
s
.
dim
()
==
2
,
"cutlass_scaled_mm group scales must be 2D"
);
return
{
cuda_utils
::
ceil_div
(
x
.
size
(
0
),
s
.
size
(
0
)),
cuda_utils
::
ceil_div
(
x
.
size
(
1
),
s
.
size
(
1
))};
};
GroupShape
a_scale_group_shape
=
make_group_shape
(
a
,
a_scales
);
GroupShape
b_scale_group_shape
=
make_group_shape
(
b
,
b_scales
);
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK
((
a_scale_group_shape
==
GroupShape
{
1
,
128
}
&&
b_scale_group_shape
==
GroupShape
{
128
,
128
}
&&
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
&&
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
),
"cutlass_scaled_mm only supports datatype float8_e4m3fn.
\n
"
"a_scale_group_shape must be [1, 128]. Got: ["
,
a_scale_group_shape
[
0
],
", "
,
a_scale_group_shape
[
1
],
"]
\n
"
"b_scale_group_shape must be [128, 128]. Got: ["
,
b_scale_group_shape
[
0
],
", "
,
b_scale_group_shape
[
1
],
"]"
);
TORCH_CHECK
(
!
bias
,
"Bias not yet supported blockwise scaled_mm"
);
vllm
::
cutlass_scaled_mm_blockwise_sm90_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
);
}
dispatch_scaled_mm
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
,
vllm
::
cutlass_scaled_mm_sm90_fp8
,
vllm
::
cutlass_scaled_mm_sm90_int8
,
vllm
::
cutlass_scaled_mm_blockwise_sm90_fp8
);
}
void
cutlass_scaled_mm_azp_sm90
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
4c676e3d
...
...
@@ -29,19 +29,15 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
void
cutlass_moe_mm_sm90
(
torch
::
Tensor
&
out_tensors
,
torch
::
Tensor
const
&
a_tensors
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
);
void
get_cutlass_moe_mm_data_caller
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
);
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
,
bool
per_act_token
,
bool
per_out_ch
);
#endif
...
...
@@ -53,6 +49,24 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \
defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100
void
get_cutlass_moe_mm_data_caller
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
void
get_cutlass_pplx_moe_mm_data_caller
(
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
torch
::
Tensor
&
expert_num_tokens
,
const
int64_t
num_local_experts
,
const
int64_t
padded_m
,
const
int64_t
n
,
const
int64_t
k
);
#endif
void
cutlass_scaled_mm_azp_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
...
...
@@ -110,6 +124,8 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
if
(
cuda_device_capability
>=
90
&&
cuda_device_capability
<
100
)
{
return
CUDA_VERSION
>=
12000
;
}
else
if
(
cuda_device_capability
>=
100
)
{
return
CUDA_VERSION
>=
12080
;
}
#endif
...
...
@@ -117,7 +133,7 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
}
bool
cutlass_group_gemm_supported
(
int64_t
cuda_device_capability
)
{
// CUTLASS groped FP8 kernels need at least CUDA 12.3
// CUTLASS gro
u
ped FP8 kernels need at least CUDA 12.3
// and SM90 (Hopper)
#if defined CUDA_VERSION
...
...
@@ -200,12 +216,13 @@ void cutlass_moe_mm(
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
)
{
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
,
bool
per_act_token
,
bool
per_out_ch
)
{
int32_t
version_num
=
get_sm_version_num
();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
cutlass_moe_mm_sm90
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
c_strides
,
per_act_token
,
per_out_ch
);
return
;
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
...
...
@@ -218,14 +235,17 @@ void get_cutlass_moe_mm_data(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
)
{
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
)
{
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t
version_num
=
get_sm_version_num
();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
get_cutlass_moe_mm_data_caller
(
topk_ids
,
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
input_permutation
,
output_permutation
,
num_experts
,
n
,
k
);
output_permutation
,
num_experts
,
n
,
k
,
blockscale_offsets
);
return
;
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
...
...
@@ -235,6 +255,29 @@ void get_cutlass_moe_mm_data(
version_num
,
". Required capability: 90"
);
}
void
get_cutlass_pplx_moe_mm_data
(
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
torch
::
Tensor
&
expert_num_tokens
,
const
int64_t
num_local_experts
,
const
int64_t
padded_m
,
const
int64_t
n
,
const
int64_t
k
)
{
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t
version_num
=
get_sm_version_num
();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
get_cutlass_pplx_moe_mm_data_caller
(
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
expert_num_tokens
,
num_local_experts
,
padded_m
,
n
,
k
);
return
;
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
"for CUDA device capability: "
,
version_num
,
". Required capability: 90"
);
}
void
cutlass_scaled_mm_azp
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
...
...
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
0 → 100644
View file @
4c676e3d
#include <torch/all.h>
#include <cutlass/arch/arch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include <cassert>
using
namespace
cute
;
template
<
typename
ElementAB
,
typename
ElementC
,
typename
ElementSF
,
typename
ElementAccumulator
,
typename
LayoutSFA
,
typename
LayoutSFB
,
typename
ScaleConfig
>
__global__
void
__get_group_gemm_starts
(
ElementAB
**
a_offsets
,
ElementAB
**
b_offsets
,
ElementC
**
out_offsets
,
ElementSF
**
a_scales_offsets
,
ElementSF
**
b_scales_offsets
,
ElementAccumulator
**
alpha_offsets
,
LayoutSFA
*
layout_sfa_base_as_int
,
LayoutSFB
*
layout_sfb_base_as_int
,
ElementAB
*
a_base_as_int
,
ElementAB
*
b_base_as_int
,
ElementC
*
out_base_as_int
,
ElementSF
*
a_scales_base_as_int
,
ElementSF
*
b_scales_base_as_int
,
ElementAccumulator
*
alphas_base_as_int
,
const
int32_t
*
expert_offsets
,
const
int32_t
*
sf_offsets
,
const
int32_t
*
problem_sizes_as_shapes
,
const
int
K
,
const
int
N
)
{
int64_t
expert_id
=
threadIdx
.
x
;
if
(
expert_id
>=
gridDim
.
x
*
blockDim
.
x
)
{
return
;
}
// Originally int32_t but upcasting to int64_t to avoid overflow
// during offset calculations
int64_t
expert_offset
=
static_cast
<
int64_t
>
(
expert_offsets
[
expert_id
]);
int64_t
sf_offset
=
static_cast
<
int64_t
>
(
sf_offsets
[
expert_id
]);
// size for block in block scale.
int64_t
group_size
=
16
;
int64_t
m
=
static_cast
<
int64_t
>
(
problem_sizes_as_shapes
[
expert_id
*
3
]);
int64_t
n
=
static_cast
<
int64_t
>
(
problem_sizes_as_shapes
[
expert_id
*
3
+
1
]);
int64_t
k
=
static_cast
<
int64_t
>
(
problem_sizes_as_shapes
[
expert_id
*
3
+
2
]);
assert
((
m
>=
0
&&
n
==
N
&&
k
==
K
&&
k
%
2
==
0
)
&&
"unexpected problem sizes"
);
int64_t
half_k
=
static_cast
<
int64_t
>
(
k
/
2
);
int64_t
group_k
=
static_cast
<
int64_t
>
(
k
/
group_size
);
// Shape of A as uint8/byte = [M, K // 2]
// Shape of B as uint8/byte = [E, N, K // 2]
a_offsets
[
expert_id
]
=
a_base_as_int
+
expert_offset
*
half_k
;
b_offsets
[
expert_id
]
=
b_base_as_int
+
expert_id
*
n
*
half_k
;
// Shape of C = [M, N]
out_offsets
[
expert_id
]
=
out_base_as_int
+
expert_offset
*
n
;
// Shape of a_scale = [sum(sf_sizes), K // group_size]
a_scales_offsets
[
expert_id
]
=
a_scales_base_as_int
+
sf_offset
*
group_k
;
assert
((
reinterpret_cast
<
uintptr_t
>
(
a_scales_offsets
[
expert_id
])
%
128
)
==
0
&&
"TMA requires 128-byte alignment"
);
// Shape of B scale = [E, N, K // group_size]
b_scales_offsets
[
expert_id
]
=
b_scales_base_as_int
+
expert_id
*
n
*
group_k
;
assert
((
reinterpret_cast
<
uintptr_t
>
(
b_scales_offsets
[
expert_id
])
%
128
)
==
0
&&
"TMA requires 128-byte alignment"
);
// Shape of alpha = [E]
alpha_offsets
[
expert_id
]
=
alphas_base_as_int
+
expert_id
;
LayoutSFA
*
layout_sfa_ptr
=
layout_sfa_base_as_int
+
expert_id
;
LayoutSFB
*
layout_sfb_ptr
=
layout_sfb_base_as_int
+
expert_id
;
*
layout_sfa_ptr
=
ScaleConfig
::
tile_atom_to_shape_SFA
(
cute
::
make_shape
(
static_cast
<
int
>
(
m
),
static_cast
<
int
>
(
n
),
static_cast
<
int
>
(
k
),
1
));
*
layout_sfb_ptr
=
ScaleConfig
::
tile_atom_to_shape_SFB
(
cute
::
make_shape
(
static_cast
<
int
>
(
m
),
static_cast
<
int
>
(
n
),
static_cast
<
int
>
(
k
),
1
));
}
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \
TENSOR_C_TYPE, C_TYPE, LayoutSFA, \
LayoutSFB, ScaleConfig) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, \
LayoutSFA, LayoutSFB, ScaleConfig> \
<<<1, num_experts, 0, stream>>>( \
static_cast<ELEMENT_AB_TYPE**>(a_starts.data_ptr()), \
static_cast<ELEMENT_AB_TYPE**>(b_starts.data_ptr()), \
static_cast<C_TYPE**>(out_starts.data_ptr()), \
static_cast<SF_TYPE**>(a_scales_starts.data_ptr()), \
static_cast<SF_TYPE**>(b_scales_starts.data_ptr()), \
static_cast<float**>(alpha_starts.data_ptr()), \
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(a_tensors.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<SF_TYPE*>(a_scales.data_ptr()), \
static_cast<SF_TYPE*>(b_scales.data_ptr()), \
static_cast<float*>(alphas.data_ptr()), \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<int32_t*>(sf_offsets.data_ptr()), \
static_cast<int32_t*>(problem_sizes.data_ptr()), K, N); \
}
template
<
typename
LayoutSFA
,
typename
LayoutSFB
,
typename
ScaleConfig
>
void
run_get_group_gemm_starts
(
const
torch
::
Tensor
&
a_starts
,
const
torch
::
Tensor
&
b_starts
,
const
torch
::
Tensor
&
out_starts
,
const
torch
::
Tensor
&
a_scales_starts
,
const
torch
::
Tensor
&
b_scales_starts
,
const
torch
::
Tensor
&
alpha_starts
,
const
torch
::
Tensor
&
layout_sfa
,
const
torch
::
Tensor
&
layout_sfb
,
/*these are used for their base addresses*/
torch
::
Tensor
const
&
a_tensors
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
out_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
alphas
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
sf_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
int
M
,
int
N
,
int
K
)
{
int
num_experts
=
(
int
)
expert_offsets
.
size
(
0
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a_tensors
.
device
().
index
());
TORCH_CHECK
(
out_tensors
.
size
(
1
)
==
N
,
"Output tensor shape doesn't match expected shape"
);
TORCH_CHECK
(
K
/
2
==
b_tensors
.
size
(
2
),
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
" dimension must match"
);
if
(
false
)
{
}
//(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
// ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE
(
cutlass
::
float_e2m1_t
,
cutlass
::
float_ue4m3_t
,
torch
::
kBFloat16
,
cutlass
::
bfloat16_t
,
LayoutSFA
,
LayoutSFB
,
ScaleConfig
)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE
(
cutlass
::
float_e2m1_t
,
cutlass
::
float_ue4m3_t
,
torch
::
kFloat16
,
half
,
LayoutSFA
,
LayoutSFB
,
ScaleConfig
)
else
{
TORCH_CHECK
(
false
,
"Invalid output type (must be float16 or bfloat16)"
);
}
}
template
<
typename
OutType
>
void
run_fp4_blockwise_scaled_group_mm
(
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
a_blockscale
,
const
torch
::
Tensor
&
b_blockscales
,
const
torch
::
Tensor
&
alphas
,
const
torch
::
Tensor
&
problem_sizes
,
const
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
sf_offsets
,
int
M
,
int
N
,
int
K
)
{
using
ProblemShape
=
cutlass
::
gemm
::
GroupProblemShape
<
Shape
<
int32_t
,
int32_t
,
int32_t
>>
;
using
ElementType
=
cutlass
::
float_e2m1_t
;
using
ElementSFType
=
cutlass
::
float_ue4m3_t
;
using
ElementA
=
cutlass
::
nv_float4_t
<
cutlass
::
float_e2m1_t
>
;
using
ElementB
=
cutlass
::
nv_float4_t
<
cutlass
::
float_e2m1_t
>
;
using
ElementC
=
OutType
;
using
ElementD
=
ElementC
;
using
ElementAccumulator
=
float
;
// Layout definitions
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
using
LayoutC
=
cutlass
::
layout
::
RowMajor
;
using
LayoutD
=
LayoutC
;
// Alignment constraints
static
constexpr
int
AlignmentA
=
32
;
static
constexpr
int
AlignmentB
=
32
;
static
constexpr
int
AlignmentC
=
128
/
cutlass
::
sizeof_bits
<
ElementC
>::
value
;
static
constexpr
int
AlignmentD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
// Architecture definitions
using
ArchTag
=
cutlass
::
arch
::
Sm100
;
using
EpilogueOperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
// Epilogue Operator class tag
using
MainloopOperatorClass
=
cutlass
::
arch
::
OpClassBlockScaledTensorOp
;
// Mainloop Operator class tag
using
StageCountType
=
cutlass
::
gemm
::
collective
::
StageCountAuto
;
// Stage count maximized based
// on the tile size
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
struct
MMA1SMConfig
{
using
MmaTileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
KernelSchedule
=
cutlass
::
gemm
::
KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100
;
// Kernel to launch
using
EpilogueSchedule
=
cutlass
::
epilogue
::
PtrArrayTmaWarpSpecialized1Sm
;
// Epilogue to launch
};
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
EpilogueOperatorClass
,
typename
MMA1SMConfig
::
MmaTileShape
,
ClusterShape
,
Shape
<
_128
,
_64
>
,
ElementAccumulator
,
ElementAccumulator
,
ElementC
,
LayoutC
*
,
AlignmentC
,
ElementD
,
LayoutC
*
,
AlignmentD
,
typename
MMA1SMConfig
::
EpilogueSchedule
>::
CollectiveOp
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
MainloopOperatorClass
,
ElementA
,
LayoutA
*
,
AlignmentA
,
ElementB
,
LayoutB
*
,
AlignmentB
,
ElementAccumulator
,
typename
MMA1SMConfig
::
MmaTileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
typename
MMA1SMConfig
::
KernelSchedule
>::
CollectiveOp
;
using
GemmKernel
=
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
ProblemShape
,
CollectiveMainloop
,
CollectiveEpilogue
>
;
using
Gemm1SM
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
using
Gemm
=
Gemm1SM
;
using
StrideA
=
typename
Gemm
::
GemmKernel
::
InternalStrideA
;
using
StrideB
=
typename
Gemm
::
GemmKernel
::
InternalStrideB
;
using
StrideC
=
typename
Gemm
::
GemmKernel
::
InternalStrideC
;
using
StrideD
=
typename
Gemm
::
GemmKernel
::
InternalStrideD
;
using
LayoutSFA
=
typename
Gemm
::
GemmKernel
::
CollectiveMainloop
::
InternalLayoutSFA
;
using
LayoutSFB
=
typename
Gemm
::
GemmKernel
::
CollectiveMainloop
::
InternalLayoutSFB
;
using
ScaleConfig
=
typename
Gemm
::
GemmKernel
::
CollectiveMainloop
::
Sm1xxBlkScaledConfig
;
using
UnderlyingProblemShape
=
ProblemShape
::
UnderlyingProblemShape
;
int
num_experts
=
static_cast
<
int
>
(
expert_offsets
.
size
(
0
));
auto
options_int
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
a
.
device
());
torch
::
Tensor
a_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
b_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
out_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
a_scales_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
b_scales_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
alpha_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
layout_sfa
=
torch
::
empty
({
num_experts
,
5
},
options_int
);
torch
::
Tensor
layout_sfb
=
torch
::
empty
({
num_experts
,
5
},
options_int
);
torch
::
Tensor
c_strides1
=
torch
::
full
({
num_experts
},
output
.
stride
(
0
),
options_int
);
torch
::
Tensor
a_strides1
=
torch
::
full
({
num_experts
},
a
.
stride
(
0
)
*
2
,
options_int
);
torch
::
Tensor
b_strides1
=
torch
::
full
({
num_experts
},
b
.
stride
(
1
)
*
2
,
options_int
);
run_get_group_gemm_starts
<
LayoutSFA
,
LayoutSFB
,
ScaleConfig
>
(
a_ptrs
,
b_ptrs
,
out_ptrs
,
a_scales_ptrs
,
b_scales_ptrs
,
alpha_ptrs
,
layout_sfa
,
layout_sfb
,
a
,
b
,
output
,
a_blockscale
,
b_blockscales
,
alphas
,
expert_offsets
,
sf_offsets
,
problem_sizes
,
M
,
N
,
K
);
// Create an instance of the GEMM
Gemm
gemm_op
;
// Initialize problem_sizes_as_shapes correctly
UnderlyingProblemShape
*
problem_sizes_as_shapes
=
static_cast
<
UnderlyingProblemShape
*>
(
problem_sizes
.
data_ptr
());
// Set the Scheduler info
cutlass
::
KernelHardwareInfo
hw_info
;
using
RasterOrderOptions
=
typename
cutlass
::
gemm
::
kernel
::
detail
::
PersistentTileSchedulerSm100GroupParams
<
typename
ProblemShape
::
UnderlyingProblemShape
>::
RasterOrderOptions
;
typename
Gemm
::
GemmKernel
::
TileSchedulerArguments
scheduler
;
scheduler
.
raster_order
=
RasterOrderOptions
::
AlongM
;
hw_info
.
device_id
=
a
.
get_device
();
static
std
::
unordered_map
<
int
,
int
>
cached_sm_counts
;
if
(
cached_sm_counts
.
find
(
hw_info
.
device_id
)
==
cached_sm_counts
.
end
())
{
cached_sm_counts
[
hw_info
.
device_id
]
=
cutlass
::
KernelHardwareInfo
::
query_device_multiprocessor_count
(
hw_info
.
device_id
);
}
hw_info
.
sm_count
=
min
(
cached_sm_counts
[
hw_info
.
device_id
],
INT_MAX
);
// Mainloop Arguments
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
static_cast
<
const
ElementType
**>
(
a_ptrs
.
data_ptr
()),
static_cast
<
StrideA
*>
(
a_strides1
.
data_ptr
()),
static_cast
<
const
ElementType
**>
(
b_ptrs
.
data_ptr
()),
static_cast
<
StrideB
*>
(
b_strides1
.
data_ptr
()),
static_cast
<
const
ElementSFType
**>
(
a_scales_ptrs
.
data_ptr
()),
reinterpret_cast
<
LayoutSFA
*>
(
layout_sfa
.
data_ptr
()),
static_cast
<
const
ElementSFType
**>
(
b_scales_ptrs
.
data_ptr
()),
reinterpret_cast
<
LayoutSFB
*>
(
layout_sfb
.
data_ptr
())};
// Epilogue Arguments
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
{},
// epilogue.thread
nullptr
,
static_cast
<
StrideC
*>
(
c_strides1
.
data_ptr
()),
static_cast
<
ElementD
**>
(
out_ptrs
.
data_ptr
()),
static_cast
<
StrideC
*>
(
c_strides1
.
data_ptr
())};
auto
&
fusion_args
=
epilogue_args
.
thread
;
fusion_args
.
alpha_ptr_array
=
reinterpret_cast
<
float
**>
(
alpha_ptrs
.
data_ptr
());
fusion_args
.
dAlpha
=
{
_0
{},
_0
{},
1
};
// Gemm Arguments
typename
GemmKernel
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGrouped
,
{
num_experts
,
problem_sizes_as_shapes
,
nullptr
},
mainloop_args
,
epilogue_args
,
hw_info
,
scheduler
};
size_t
workspace_size
=
Gemm
::
get_workspace_size
(
args
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
auto
can_implement_status
=
gemm_op
.
can_implement
(
args
);
TORCH_CHECK
(
can_implement_status
==
cutlass
::
Status
::
kSuccess
,
"Failed to implement GEMM"
);
// Run the GEMM
auto
status
=
gemm_op
.
initialize
(
args
,
workspace
.
data_ptr
());
TORCH_CHECK
(
status
==
cutlass
::
Status
::
kSuccess
,
"Failed to initialize GEMM"
);
status
=
gemm_op
.
run
(
args
,
workspace
.
data_ptr
(),
stream
);
TORCH_CHECK
(
status
==
cutlass
::
Status
::
kSuccess
,
"Failed to run GEMM"
);
}
constexpr
auto
FLOAT4_E2M1X2
=
at
::
ScalarType
::
Byte
;
constexpr
auto
SF_DTYPE
=
at
::
ScalarType
::
Float8_e4m3fn
;
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
void
cutlass_fp4_group_mm
(
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
a_blockscale
,
const
torch
::
Tensor
&
b_blockscales
,
const
torch
::
Tensor
&
alphas
,
const
torch
::
Tensor
&
problem_sizes
,
const
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
sf_offsets
)
{
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
// Input validation
CHECK_INPUT
(
a
,
FLOAT4_E2M1X2
,
"a"
);
CHECK_INPUT
(
b
,
FLOAT4_E2M1X2
,
"b"
);
CHECK_INPUT
(
a_blockscale
,
SF_DTYPE
,
"a_blockscale"
);
CHECK_INPUT
(
b_blockscales
,
SF_DTYPE
,
"b_blockscales"
);
CHECK_INPUT
(
alphas
,
at
::
ScalarType
::
Float
,
"alphas"
);
TORCH_CHECK
(
a_blockscale
.
dim
()
==
2
,
"expected a_blockscale to be of shape [num_experts, rounded_m,"
" k // group_size], observed rank: "
,
a_blockscale
.
dim
())
TORCH_CHECK
(
b_blockscales
.
dim
()
==
3
,
"expected b_blockscale to be of shape: "
" [num_experts, n, k // group_size], observed rank: "
,
b_blockscales
.
dim
())
TORCH_CHECK
(
problem_sizes
.
dim
()
==
2
,
"problem_sizes must be a 2D tensor"
);
TORCH_CHECK
(
problem_sizes
.
size
(
1
)
==
3
,
"problem_sizes must have the shape (num_experts, 3)"
);
TORCH_CHECK
(
problem_sizes
.
size
(
0
)
==
expert_offsets
.
size
(
0
),
"Number of experts in problem_sizes must match expert_offsets"
);
TORCH_CHECK
(
problem_sizes
.
dtype
()
==
torch
::
kInt32
,
"problem_sizes must be int32."
);
int
M
=
static_cast
<
int
>
(
a
.
size
(
0
));
int
N
=
static_cast
<
int
>
(
b
.
size
(
1
));
int
E
=
static_cast
<
int
>
(
b
.
size
(
0
));
int
K
=
static_cast
<
int
>
(
2
*
b
.
size
(
2
));
if
(
output
.
scalar_type
()
==
torch
::
kBFloat16
)
{
run_fp4_blockwise_scaled_group_mm
<
cutlass
::
bfloat16_t
>
(
output
,
a
,
b
,
a_blockscale
,
b_blockscales
,
alphas
,
problem_sizes
,
expert_offsets
,
sf_offsets
,
M
,
N
,
K
);
}
else
{
run_fp4_blockwise_scaled_group_mm
<
cutlass
::
half_t
>
(
output
,
a
,
b
,
a_blockscale
,
b_blockscales
,
alphas
,
problem_sizes
,
expert_offsets
,
sf_offsets
,
M
,
N
,
K
);
}
#else
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
"be compiled with ENABLE_NVFP4 for SM100+ and CUDA "
"12.8 or above."
);
#endif
}
csrc/quantization/fp4/nvfp4_experts_quant.cu
0 → 100644
View file @
4c676e3d
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h>
template
<
typename
T
>
struct
TypeConverter
{
using
Type
=
half2
;
};
// keep for generality
template
<
>
struct
TypeConverter
<
half2
>
{
using
Type
=
half
;
};
template
<
>
struct
TypeConverter
<
half
>
{
using
Type
=
half2
;
};
template
<
>
struct
TypeConverter
<
__nv_bfloat162
>
{
using
Type
=
__nv_bfloat16
;
};
template
<
>
struct
TypeConverter
<
__nv_bfloat16
>
{
using
Type
=
__nv_bfloat162
;
};
#define ELTS_PER_THREAD 8
constexpr
int
CVT_FP4_ELTS_PER_THREAD
=
8
;
constexpr
int
CVT_FP4_SF_VEC_SIZE
=
16
;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline
__device__
uint32_t
fp32_vec_to_e2m1
(
float
(
&
array
)[
8
])
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t
val
;
asm
volatile
(
"{
\n
"
".reg .b8 byte0;
\n
"
".reg .b8 byte1;
\n
"
".reg .b8 byte2;
\n
"
".reg .b8 byte3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;
\n
"
"mov.b32 %0, {byte0, byte1, byte2, byte3};
\n
"
"}"
:
"=r"
(
val
)
:
"f"
(
array
[
0
]),
"f"
(
array
[
1
]),
"f"
(
array
[
2
]),
"f"
(
array
[
3
]),
"f"
(
array
[
4
]),
"f"
(
array
[
5
]),
"f"
(
array
[
6
]),
"f"
(
array
[
7
]));
return
val
;
#else
return
0
;
#endif
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline
__device__
uint32_t
fp32_vec_to_e2m1
(
float2
(
&
array
)[
4
])
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t
val
;
asm
volatile
(
"{
\n
"
".reg .b8 byte0;
\n
"
".reg .b8 byte1;
\n
"
".reg .b8 byte2;
\n
"
".reg .b8 byte3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;
\n
"
"mov.b32 %0, {byte0, byte1, byte2, byte3};
\n
"
"}"
:
"=r"
(
val
)
:
"f"
(
array
[
0
].
x
),
"f"
(
array
[
0
].
y
),
"f"
(
array
[
1
].
x
),
"f"
(
array
[
1
].
y
),
"f"
(
array
[
2
].
x
),
"f"
(
array
[
2
].
y
),
"f"
(
array
[
3
].
x
),
"f"
(
array
[
3
].
y
));
return
val
;
#else
return
0
;
#endif
}
// Fast reciprocal.
inline
__device__
float
reciprocal_approximate_ftz
(
float
a
)
{
float
b
;
asm
volatile
(
"rcp.approx.ftz.f32 %0, %1;
\n
"
:
"=f"
(
b
)
:
"f"
(
a
));
return
b
;
}
template
<
class
SFType
,
int
CVT_FP4_NUM_THREADS_PER_SF
>
__device__
uint8_t
*
cvt_quant_to_fp4_get_sf_out_offset
(
int
rowIdx
,
int
colIdx
,
int
numCols
,
SFType
*
SFout
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert
(
CVT_FP4_NUM_THREADS_PER_SF
==
1
||
CVT_FP4_NUM_THREADS_PER_SF
==
2
);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if
(
threadIdx
.
x
%
CVT_FP4_NUM_THREADS_PER_SF
==
0
)
{
// SF vector index (16 elements share one SF in the K dimension).
int32_t
kIdx
=
colIdx
/
CVT_FP4_NUM_THREADS_PER_SF
;
int32_t
mIdx
=
rowIdx
;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t
mTileIdx
=
mIdx
/
(
32
*
4
);
// SF vector size 16.
int
factor
=
CVT_FP4_SF_VEC_SIZE
*
4
;
int32_t
numKTiles
=
(
numCols
+
factor
-
1
)
/
factor
;
int64_t
mTileStride
=
numKTiles
*
32
*
4
*
4
;
int32_t
kTileIdx
=
(
kIdx
/
4
);
int64_t
kTileStride
=
32
*
4
*
4
;
// M tile layout [32, 4] is column-major.
int32_t
outerMIdx
=
(
mIdx
%
32
);
int64_t
outerMStride
=
4
*
4
;
int32_t
innerMIdx
=
(
mIdx
%
(
32
*
4
))
/
32
;
int64_t
innerMStride
=
4
;
int32_t
innerKIdx
=
(
kIdx
%
4
);
int64_t
innerKStride
=
1
;
// Compute the global offset.
int64_t
SFOffset
=
mTileIdx
*
mTileStride
+
kTileIdx
*
kTileStride
+
outerMIdx
*
outerMStride
+
innerMIdx
*
innerMStride
+
innerKIdx
*
innerKStride
;
return
reinterpret_cast
<
uint8_t
*>
(
SFout
)
+
SFOffset
;
}
#endif
return
nullptr
;
}
// Define a 16 bytes packed data type.
template
<
class
Type
>
struct
PackedVec
{
typename
TypeConverter
<
Type
>::
Type
elts
[
4
];
};
template
<
>
struct
PackedVec
<
__nv_fp8_e4m3
>
{
__nv_fp8x2_e4m3
elts
[
8
];
};
// Quantizes the provided PackedVec into the uint32_t output
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
__device__
uint32_t
cvt_warp_fp16_to_fp4
(
PackedVec
<
Type
>&
vec
,
float
SFScaleVal
,
uint8_t
*
SFout
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Get absolute maximum values among the local 8 values.
auto
localMax
=
__habs2
(
vec
.
elts
[
0
]);
// Local maximum value.
#pragma unroll
for
(
int
i
=
1
;
i
<
CVT_FP4_ELTS_PER_THREAD
/
2
;
i
++
)
{
localMax
=
__hmax2
(
localMax
,
__habs2
(
vec
.
elts
[
i
]));
}
// Get the absolute maximum among all 16 values (two threads).
localMax
=
__hmax2
(
__shfl_xor_sync
(
uint32_t
(
-
1
),
localMax
,
1
),
localMax
);
// Get the final absolute maximum values.
float
vecMax
=
float
(
__hmax
(
localMax
.
x
,
localMax
.
y
));
// Get the SF (max value of the vector / max value of e2m1).
// maximum value of e2m1 = 6.0.
// TODO: use half as compute data type.
float
SFValue
=
SFScaleVal
*
(
vecMax
*
reciprocal_approximate_ftz
(
6.0
f
));
// 8 bits representation of the SF.
uint8_t
fp8SFVal
;
// Write the SF to global memory (STG.8).
if
constexpr
(
UE8M0_SF
)
{
// Extract the 8 exponent bits from float32.
// float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
uint32_t
tmp
=
reinterpret_cast
<
uint32_t
&>
(
SFValue
)
>>
23
;
fp8SFVal
=
tmp
&
0xff
;
// Convert back to fp32.
reinterpret_cast
<
uint32_t
&>
(
SFValue
)
=
tmp
<<
23
;
}
else
{
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
__nv_fp8_e4m3
tmp
=
__nv_fp8_e4m3
(
SFValue
);
reinterpret_cast
<
__nv_fp8_e4m3
&>
(
fp8SFVal
)
=
tmp
;
// Convert back to fp32.
SFValue
=
float
(
tmp
);
}
// Get the output scale.
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
// reciprocal(SFScaleVal))
float
outputScale
=
SFValue
!=
0
?
reciprocal_approximate_ftz
(
SFValue
*
reciprocal_approximate_ftz
(
SFScaleVal
))
:
0.0
f
;
if
(
SFout
)
{
// Write the SF to global memory (STG.8).
*
SFout
=
fp8SFVal
;
}
// Convert the input to float.
float2
fp2Vals
[
CVT_FP4_ELTS_PER_THREAD
/
2
];
#pragma unroll
for
(
int
i
=
0
;
i
<
CVT_FP4_ELTS_PER_THREAD
/
2
;
i
++
)
{
if
constexpr
(
std
::
is_same_v
<
Type
,
half
>
)
{
fp2Vals
[
i
]
=
__half22float2
(
vec
.
elts
[
i
]);
}
else
{
fp2Vals
[
i
]
=
__bfloat1622float2
(
vec
.
elts
[
i
]);
}
fp2Vals
[
i
].
x
*=
outputScale
;
fp2Vals
[
i
].
y
*=
outputScale
;
}
// Convert to e2m1 values.
uint32_t
e2m1Vec
=
fp32_vec_to_e2m1
(
fp2Vals
);
// Write the e2m1 values to global memory.
return
e2m1Vec
;
#else
return
0
;
#endif
}
// Use UE4M3 by default.
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
__global__
void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__
(
512
,
4
)
cvt_fp16_to_fp4
(
#else
cvt_fp16_to_fp4
(
#endif
int32_t
numRows
,
int32_t
numCols
,
Type
const
*
in
,
float
const
*
SFScale
,
uint32_t
*
out
,
uint32_t
*
SFout
,
uint32_t
*
input_offset_by_experts
,
uint32_t
*
output_scale_offset_by_experts
,
int
n_experts
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using
PackedVec
=
PackedVec
<
Type
>
;
static
constexpr
int
CVT_FP4_NUM_THREADS_PER_SF
=
(
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
);
static_assert
(
sizeof
(
PackedVec
)
==
sizeof
(
Type
)
*
CVT_FP4_ELTS_PER_THREAD
,
"Vec size is not matched."
);
// Input tensor row/col loops.
for
(
int
rowIdx
=
blockIdx
.
x
;
rowIdx
<
numRows
;
rowIdx
+=
gridDim
.
x
)
{
for
(
int
colIdx
=
threadIdx
.
x
;
colIdx
<
numCols
/
CVT_FP4_ELTS_PER_THREAD
;
colIdx
+=
blockDim
.
x
)
{
int64_t
inOffset
=
rowIdx
*
(
numCols
/
CVT_FP4_ELTS_PER_THREAD
)
+
colIdx
;
PackedVec
in_vec
=
reinterpret_cast
<
PackedVec
const
*>
(
in
)[
inOffset
];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t
outOffset
=
inOffset
;
auto
&
out_pos
=
out
[
outOffset
];
// Find index within the experts.
int
rowIdx_in_expert
=
0
;
int
expert_idx
=
0
;
for
(
int
i
=
0
;
i
<
n_experts
;
i
++
)
{
if
(
rowIdx
>=
input_offset_by_experts
[
i
]
&&
rowIdx
<
input_offset_by_experts
[
i
+
1
])
{
rowIdx_in_expert
=
rowIdx
-
input_offset_by_experts
[
i
];
expert_idx
=
i
;
break
;
}
}
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float
const
SFScaleVal
=
SFScale
==
nullptr
?
1.0
f
:
SFScale
[
expert_idx
];
int
factor
=
CVT_FP4_SF_VEC_SIZE
*
4
;
// The actual output_scales dim is computed from the padded numCols.
int32_t
numCols_padded
=
(
numCols
+
factor
-
1
)
/
factor
*
factor
;
int
numCols_SFout
=
numCols_padded
/
CVT_FP4_SF_VEC_SIZE
/
4
;
uint32_t
*
SFout_in_expert
=
SFout
+
output_scale_offset_by_experts
[
expert_idx
]
*
numCols_SFout
;
auto
sf_out
=
cvt_quant_to_fp4_get_sf_out_offset
<
uint32_t
,
CVT_FP4_NUM_THREADS_PER_SF
>
(
rowIdx_in_expert
,
colIdx
,
numCols
,
SFout_in_expert
);
out_pos
=
cvt_warp_fp16_to_fp4
<
Type
,
UE8M0_SF
>
(
in_vec
,
SFScaleVal
,
sf_out
);
}
}
#endif
}
template
<
typename
T
>
void
quant_impl
(
void
*
output
,
void
*
output_scale
,
void
*
input
,
void
*
input_global_scale
,
void
*
input_offset_by_experts
,
void
*
output_scale_offset_by_experts
,
int
m_topk
,
int
k
,
int
n_experts
,
cudaStream_t
stream
)
{
// TODO: this multiProcessorCount should be cached.
int
device
;
cudaGetDevice
(
&
device
);
int
multiProcessorCount
;
cudaDeviceGetAttribute
(
&
multiProcessorCount
,
cudaDevAttrMultiProcessorCount
,
device
);
// Grid, Block size.
// Each thread converts 8 values.
dim3
block
(
std
::
min
(
int
(
k
/
ELTS_PER_THREAD
),
512
));
// Get number of blocks per SM (assume we can fully utilize the SM).
int
const
numBlocksPerSM
=
2048
/
block
.
x
;
dim3
grid
(
std
::
min
(
int
(
m_topk
),
multiProcessorCount
*
numBlocksPerSM
));
cvt_fp16_to_fp4
<
T
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
m_topk
,
k
,
reinterpret_cast
<
T
*>
(
input
),
reinterpret_cast
<
float
*>
(
input_global_scale
),
reinterpret_cast
<
uint32_t
*>
(
output
),
reinterpret_cast
<
uint32_t
*>
(
output_scale
),
reinterpret_cast
<
uint32_t
*>
(
input_offset_by_experts
),
reinterpret_cast
<
uint32_t
*>
(
output_scale_offset_by_experts
),
n_experts
);
}
/*Quantization entry for fp4 experts quantization*/
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m);
constexpr
auto
HALF
=
at
::
ScalarType
::
Half
;
constexpr
auto
BF16
=
at
::
ScalarType
::
BFloat16
;
constexpr
auto
FLOAT
=
at
::
ScalarType
::
Float
;
constexpr
auto
INT
=
at
::
ScalarType
::
Int
;
constexpr
auto
UINT8
=
at
::
ScalarType
::
Byte
;
void
scaled_fp4_experts_quant_sm100a
(
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_global_scale
,
torch
::
Tensor
const
&
input_offset_by_experts
,
torch
::
Tensor
const
&
output_scale_offset_by_experts
)
{
CHECK_INPUT
(
output
,
"output must be a CUDA tensor"
);
CHECK_INPUT
(
output_scale
,
"output_scale must be a CUDA tensor"
);
CHECK_INPUT
(
input
,
"input must be a CUDA tensor"
);
CHECK_INPUT
(
input_global_scale
,
"input_global_scale must be a CUDA tensor"
);
CHECK_INPUT
(
input_offset_by_experts
,
"input_offset_by_experts must be a CUDA tensor"
);
CHECK_INPUT
(
output_scale_offset_by_experts
,
"output_scale_offset_by_experts must be a CUDA tensor"
);
TORCH_CHECK
(
output
.
dim
()
==
2
);
TORCH_CHECK
(
output_scale
.
dim
()
==
2
);
TORCH_CHECK
(
input
.
dim
()
==
2
);
TORCH_CHECK
(
input_global_scale
.
dim
()
==
1
);
TORCH_CHECK
(
input_offset_by_experts
.
dim
()
==
1
);
TORCH_CHECK
(
output_scale_offset_by_experts
.
dim
()
==
1
);
TORCH_CHECK
(
input
.
scalar_type
()
==
HALF
||
input
.
scalar_type
()
==
BF16
);
TORCH_CHECK
(
input_global_scale
.
scalar_type
()
==
FLOAT
);
TORCH_CHECK
(
input_offset_by_experts
.
scalar_type
()
==
INT
);
TORCH_CHECK
(
output_scale_offset_by_experts
.
scalar_type
()
==
INT
);
// output is uint8 (two nvfp4 values are packed into one uint8)
// output_scale is int32 (four fp8 values are packed into one int32)
TORCH_CHECK
(
output
.
scalar_type
()
==
UINT8
);
TORCH_CHECK
(
output_scale
.
scalar_type
()
==
INT
);
const
int
BLOCK_SIZE
=
16
;
auto
m_topk
=
input
.
size
(
0
);
auto
k
=
input
.
size
(
1
);
TORCH_CHECK
(
k
%
BLOCK_SIZE
==
0
,
"k must be a multiple of 16"
);
auto
n_experts
=
input_global_scale
.
size
(
0
);
TORCH_CHECK
(
input_offset_by_experts
.
size
(
0
)
==
n_experts
+
1
);
TORCH_CHECK
(
output_scale_offset_by_experts
.
size
(
0
)
==
n_experts
+
1
);
TORCH_CHECK
(
output
.
size
(
0
)
==
m_topk
);
TORCH_CHECK
(
output
.
size
(
1
)
==
k
/
2
);
int
scales_k
=
k
/
BLOCK_SIZE
;
// 4 means the swizzle requirement by nvidia nvfp4.
int
padded_k
=
(
scales_k
+
(
4
-
1
))
/
4
*
4
;
// 4 means 4 fp8 values are packed into one int32
TORCH_CHECK
(
output_scale
.
size
(
1
)
*
4
==
padded_k
);
auto
in_dtype
=
input
.
dtype
();
at
::
cuda
::
CUDAGuard
device_guard
{(
char
)
input
.
get_device
()};
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
input
.
get_device
());
if
(
in_dtype
==
at
::
ScalarType
::
Half
)
{
quant_impl
<
half
>
(
output
.
data_ptr
(),
output_scale
.
data_ptr
(),
input
.
data_ptr
(),
input_global_scale
.
data_ptr
(),
input_offset_by_experts
.
data_ptr
(),
output_scale_offset_by_experts
.
data_ptr
(),
m_topk
,
k
,
n_experts
,
stream
);
}
else
if
(
in_dtype
==
at
::
ScalarType
::
BFloat16
)
{
quant_impl
<
__nv_bfloat16
>
(
output
.
data_ptr
(),
output_scale
.
data_ptr
(),
input
.
data_ptr
(),
input_global_scale
.
data_ptr
(),
input_offset_by_experts
.
data_ptr
(),
output_scale_offset_by_experts
.
data_ptr
(),
m_topk
,
k
,
n_experts
,
stream
);
}
else
{
TORCH_CHECK
(
false
,
"Expected input data type to be half or bfloat16"
);
}
}
\ No newline at end of file
csrc/quantization/fp4/nvfp4_quant_entry.cu
View file @
4c676e3d
...
...
@@ -23,10 +23,32 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output,
torch
::
Tensor
const
&
input_sf
);
#endif
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
void
scaled_fp4_experts_quant_sm100a
(
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_global_scale
,
torch
::
Tensor
const
&
input_offset_by_experts
,
torch
::
Tensor
const
&
output_scale_offset_by_experts
);
#endif
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
output_sf
,
torch
::
Tensor
const
&
input_sf
)
{
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
return
scaled_fp4_quant_sm100a
(
output
,
input
,
output_sf
,
input_sf
);
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled nvfp4 quantization"
);
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled nvfp4 quantization kernel"
);
}
void
scaled_fp4_experts_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_global_scale
,
torch
::
Tensor
const
&
input_offset_by_experts
,
torch
::
Tensor
const
&
output_scale_offset_by_experts
)
{
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
return
scaled_fp4_experts_quant_sm100a
(
output
,
output_scale
,
input
,
input_global_scale
,
input_offset_by_experts
,
output_scale_offset_by_experts
);
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled nvfp4 experts quantization kernel"
);
}
csrc/quantization/fp8/common.cu
View file @
4c676e3d
...
...
@@ -39,8 +39,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
fp8_type
*
__restrict__
token_output
=
&
out
[
offset
];
// For vectorization, token_input and token_output pointers need to be
// aligned at
8
-byte and
4
-byte addresses respectively.
bool
const
can_vectorize
=
hidden_size
%
4
==
0
;
// aligned at
32
-byte and
16
-byte addresses respectively.
bool
const
can_vectorize
=
hidden_size
%
16
==
0
;
float
absmax_val
=
0.0
f
;
if
(
can_vectorize
)
{
...
...
@@ -48,24 +48,24 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
}
else
{
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
float
const
x
=
static_cast
<
float
>
(
token_input
[
i
]);
absmax_val
=
max
(
absmax_val
,
fabs
(
x
));
absmax_val
=
f
max
f
(
absmax_val
,
fabs
f
(
x
));
}
}
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
256
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStorage
;
float
const
block_absmax_val_maybe
=
BlockReduce
(
reduceStorage
).
Reduce
(
absmax_val
,
cub
::
Max
{},
blockDim
.
x
);
__shared__
float
token_scale
;
if
(
tid
==
0
)
{
if
(
scale_ub
)
{
token_scale
=
min
(
block_absmax_val_maybe
,
*
scale_ub
);
token_scale
=
f
min
f
(
block_absmax_val_maybe
,
*
scale_ub
);
}
else
{
token_scale
=
block_absmax_val_maybe
;
}
// token scale computation
token_scale
=
max
(
token_scale
/
quant_type_max_v
<
fp8_type
>
,
min_scaling_factor
<
fp8_type
>::
val
());
token_scale
=
f
max
f
(
token_scale
/
quant_type_max_v
<
fp8_type
>
,
min_scaling_factor
<
fp8_type
>::
val
());
scale
[
token_idx
]
=
token_scale
;
}
__syncthreads
();
...
...
@@ -88,10 +88,11 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch
::
Tensor
const
&
input
,
// [..., d]
torch
::
Tensor
const
&
scale
)
// [1]
{
int64_t
num_tokens
=
input
.
numel
()
/
input
.
size
(
-
1
);
int64_t
num_elems
=
input
.
numel
();
dim3
grid
(
num_tokens
);
dim3
block
(
1024
);
int
const
block_size
=
256
;
int
const
num_tokens
=
input
.
numel
()
/
input
.
size
(
-
1
);
int
const
num_elems
=
input
.
numel
();
dim3
const
grid
(
num_tokens
);
dim3
const
block
(
block_size
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
...
...
@@ -110,10 +111,11 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch
::
Tensor
const
&
input
,
// [..., d]
torch
::
Tensor
&
scale
)
// [1]
{
int64_t
num_tokens
=
input
.
numel
()
/
input
.
size
(
-
1
);
int64_t
num_elems
=
input
.
numel
();
dim3
grid
(
num_tokens
);
dim3
block
(
1024
);
int
const
block_size
=
256
;
int
const
num_tokens
=
input
.
numel
()
/
input
.
size
(
-
1
);
int
const
num_elems
=
input
.
numel
();
dim3
const
grid
(
num_tokens
);
dim3
const
block
(
block_size
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
...
...
@@ -141,8 +143,9 @@ void dynamic_per_token_scaled_fp8_quant(
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
int
const
block_size
=
256
;
dim3
const
grid
(
num_tokens
);
dim3
const
block
(
std
::
min
(
hidden_size
,
1024
));
dim3
const
block
(
std
::
min
(
hidden_size
,
block_size
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
...
...
Prev
1
…
4
5
6
7
8
9
10
11
12
…
21
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment