Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4c676e3d
Commit
4c676e3d
authored
Jun 20, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.9.1' into v0.9.1-dev
parents
b4c4464d
b6553be1
Changes
418
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1721 additions
and
179 deletions
+1721
-179
csrc/ops.h
csrc/ops.h
+65
-6
csrc/opt/activation_kernels_opt.cu
csrc/opt/activation_kernels_opt.cu
+3
-0
csrc/pos_encoding_kernels.cu
csrc/pos_encoding_kernels.cu
+80
-52
csrc/quantization/activation_kernels.cu
csrc/quantization/activation_kernels.cu
+121
-0
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+14
-2
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
...ization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
+23
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
...tlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
+279
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
+75
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
+5
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
...ization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
+51
-2
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu
+19
-10
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh
+2
-4
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
+69
-9
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
+5
-17
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
+5
-46
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+57
-14
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
+402
-0
csrc/quantization/fp4/nvfp4_experts_quant.cu
csrc/quantization/fp4/nvfp4_experts_quant.cu
+404
-0
csrc/quantization/fp4/nvfp4_quant_entry.cu
csrc/quantization/fp4/nvfp4_quant_entry.cu
+23
-1
csrc/quantization/fp8/common.cu
csrc/quantization/fp8/common.cu
+19
-16
No files found.
Too many changes to show.
To preserve performance only
418 of 418+
files are displayed.
Plain diff
Email patch
csrc/ops.h
View file @
4c676e3d
...
@@ -179,6 +179,31 @@ void merge_attn_states(torch::Tensor& output,
...
@@ -179,6 +179,31 @@ void merge_attn_states(torch::Tensor& output,
const
torch
::
Tensor
&
prefix_lse
,
const
torch
::
Tensor
&
prefix_lse
,
const
torch
::
Tensor
&
suffix_output
,
const
torch
::
Tensor
&
suffix_output
,
const
torch
::
Tensor
&
suffix_lse
);
const
torch
::
Tensor
&
suffix_lse
);
void
convert_vertical_slash_indexes
(
torch
::
Tensor
&
block_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
block_offset
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch
::
Tensor
&
column_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
column_index
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch
::
Tensor
q_seqlens
,
// [BATCH, ]
torch
::
Tensor
kv_seqlens
,
// [BATCH, ]
torch
::
Tensor
vertical_indexes
,
// [BATCH, N_HEADS, NNZ_V]
torch
::
Tensor
slash_indexes
,
// [BATCH, N_HEADS, NNZ_S]
int64_t
context_size
,
int64_t
block_size_M
,
int64_t
block_size_N
,
bool
causal
);
void
convert_vertical_slash_indexes_mergehead
(
torch
::
Tensor
&
block_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
block_offset
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch
::
Tensor
&
column_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
column_index
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch
::
Tensor
q_seqlens
,
// [BATCH, ]
torch
::
Tensor
kv_seqlens
,
// [BATCH, ]
torch
::
Tensor
vertical_indexes
,
// [BATCH, N_HEADS, NNZ_V]
torch
::
Tensor
slash_indexes
,
// [BATCH, N_HEADS, NNZ_S]
torch
::
Tensor
vertical_indices_count
,
// [N_HEADS, ]
torch
::
Tensor
slash_indices_count
,
int64_t
context_size
,
int64_t
block_size_M
,
int64_t
block_size_N
,
bool
causal
);
#endif
#endif
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
...
@@ -193,6 +218,11 @@ void rms_norm_opt(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weigh
...
@@ -193,6 +218,11 @@ void rms_norm_opt(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weigh
void
fused_add_rms_norm_opt
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
void
fused_add_rms_norm_opt
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
double
epsilon
);
torch
::
Tensor
&
weight
,
double
epsilon
);
void
apply_repetition_penalties_
(
torch
::
Tensor
&
logits
,
const
torch
::
Tensor
&
prompt_mask
,
const
torch
::
Tensor
&
output_mask
,
const
torch
::
Tensor
&
repetition_penalties
);
// void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& weight, torch::Tensor& scale,
// torch::Tensor& weight, torch::Tensor& scale,
// double epsilon);
// double epsilon);
...
@@ -212,13 +242,13 @@ void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual,
...
@@ -212,13 +242,13 @@ void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual,
// std::optional<torch::Tensor> residual);
// std::optional<torch::Tensor> residual);
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int64_t
head_size
,
std
::
optional
<
torch
::
Tensor
>
key
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int64_t
head_size
,
std
::
optional
<
torch
::
Tensor
>
key
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
int64_t
rot_dim
,
bool
is_neox
,
int64_t
rot_dim
,
torch
::
Tensor
&
cos_sin_cache_offsets
);
torch
::
Tensor
&
cos_sin_cache_offsets
);
void
rotary_embedding_tgi
(
void
rotary_embedding_tgi
(
torch
::
Tensor
&
query
,
torch
::
Tensor
&
query
,
...
@@ -230,6 +260,9 @@ void rotary_embedding_tgi(
...
@@ -230,6 +260,9 @@ void rotary_embedding_tgi(
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
silu_and_mul_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
scale
);
void
mul_and_silu
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
mul_and_silu
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
...
@@ -316,6 +349,10 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
...
@@ -316,6 +349,10 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
torch
::
Tensor
num_tokens_post_padded
,
int64_t
type
,
torch
::
Tensor
num_tokens_post_padded
,
int64_t
type
,
int64_t
row
,
int64_t
top_k
,
int64_t
tokens
);
int64_t
row
,
int64_t
top_k
,
int64_t
tokens
);
torch
::
Tensor
ggml_moe_a8_vec
(
torch
::
Tensor
X
,
torch
::
Tensor
W
,
torch
::
Tensor
topk_ids
,
int64_t
top_k
,
int64_t
type
,
int64_t
row
,
int64_t
tokens
);
int64_t
ggml_moe_get_block_size
(
int64_t
type
);
int64_t
ggml_moe_get_block_size
(
int64_t
type
);
#ifndef USE_ROCM
#ifndef USE_ROCM
...
@@ -340,13 +377,29 @@ void cutlass_moe_mm(
...
@@ -340,13 +377,29 @@ void cutlass_moe_mm(
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
);
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
,
bool
per_act_token
,
bool
per_out_ch
);
void
cutlass_fp4_group_mm
(
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
a_blockscale
,
const
torch
::
Tensor
&
b_blockscales
,
const
torch
::
Tensor
&
alphas
,
const
torch
::
Tensor
&
problem_sizes
,
const
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
sf_offsets
);
void
get_cutlass_moe_mm_data
(
void
get_cutlass_moe_mm_data
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
);
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
void
get_cutlass_pplx_moe_mm_data
(
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
torch
::
Tensor
&
expert_num_tokens
,
const
int64_t
num_local_experts
,
const
int64_t
padded_m
,
const
int64_t
n
,
const
int64_t
k
);
void
cutlass_scaled_mm_azp
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_azp
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
...
@@ -369,6 +422,12 @@ std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
...
@@ -369,6 +422,12 @@ std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input_scale
);
torch
::
Tensor
const
&
input_scale
);
void
scaled_fp4_experts_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_global_scale
,
torch
::
Tensor
const
&
input_offset_by_experts
,
torch
::
Tensor
const
&
output_scale_offset_by_experts
);
#endif
#endif
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
...
...
csrc/opt/activation_kernels_opt.cu
View file @
4c676e3d
...
@@ -161,6 +161,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
...
@@ -161,6 +161,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
int64_t num_tokens = input.numel() / input.size(-1); \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
dim3 block(std::min(d, 1024)); \
if (num_tokens == 0) { \
return; \
}
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
\
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
\
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
\
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
\
VLLM_DISPATCH_FLOATING_TYPES
(
\
VLLM_DISPATCH_FLOATING_TYPES
(
\
...
...
csrc/pos_encoding_kernels.cu
View file @
4c676e3d
...
@@ -38,12 +38,14 @@ inline __device__ void apply_rotary_embedding(
...
@@ -38,12 +38,14 @@ inline __device__ void apply_rotary_embedding(
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
scalar_t
*
__restrict__
key
,
// nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
// head_size]
const
scalar_t
*
cache_ptr
,
const
int
head_size
,
const
int
num_heads
,
const
scalar_t
*
cache_ptr
,
const
int
head_size
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
rot_dim
,
const
int
token_idx
,
const
int
num_kv_heads
,
const
int
rot_dim
,
const
int
token_idx
,
const
int64_t
query_stride
,
const
int64_t
key_stride
)
{
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int64_t
head_stride
)
{
const
int
embed_dim
=
rot_dim
/
2
;
const
int
embed_dim
=
rot_dim
/
2
;
const
scalar_t
*
cos_ptr
=
cache_ptr
;
const
scalar_t
*
cos_ptr
=
cache_ptr
;
const
scalar_t
*
sin_ptr
=
cache_ptr
+
embed_dim
;
const
scalar_t
*
sin_ptr
=
cache_ptr
+
embed_dim
;
...
@@ -51,19 +53,23 @@ inline __device__ void apply_rotary_embedding(
...
@@ -51,19 +53,23 @@ inline __device__ void apply_rotary_embedding(
const
int
nq
=
num_heads
*
embed_dim
;
const
int
nq
=
num_heads
*
embed_dim
;
for
(
int
i
=
threadIdx
.
x
;
i
<
nq
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
nq
;
i
+=
blockDim
.
x
)
{
const
int
head_idx
=
i
/
embed_dim
;
const
int
head_idx
=
i
/
embed_dim
;
const
int64_t
token_head
=
token_idx
*
query_stride
+
head_idx
*
head_size
;
const
int64_t
token_head
=
token_idx
*
query_stride
+
head_idx
*
head_stride
;
const
int
rot_offset
=
i
%
embed_dim
;
const
int
rot_offset
=
i
%
embed_dim
;
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
query
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
}
}
const
int
nk
=
num_kv_heads
*
embed_dim
;
if
(
key
!=
nullptr
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
nk
;
i
+=
blockDim
.
x
)
{
const
int
nk
=
num_kv_heads
*
embed_dim
;
const
int
head_idx
=
i
/
embed_dim
;
for
(
int
i
=
threadIdx
.
x
;
i
<
nk
;
i
+=
blockDim
.
x
)
{
const
int64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_size
;
const
int
head_idx
=
i
/
embed_dim
;
const
int
rot_offset
=
i
%
embed_dim
;
const
int64_t
token_head
=
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
token_idx
*
key_stride
+
head_idx
*
head_stride
;
key
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
const
int
rot_offset
=
i
%
embed_dim
;
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
key
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
}
}
}
}
}
...
@@ -74,13 +80,15 @@ __global__ void rotary_embedding_kernel(
...
@@ -74,13 +80,15 @@ __global__ void rotary_embedding_kernel(
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
scalar_t
*
__restrict__
key
,
// nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
// head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
// 2]
// 2]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
const
int64_t
head_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
// Each thread block is responsible for one token.
// Each thread block is responsible for one token.
const
int
token_idx
=
blockIdx
.
x
;
const
int
token_idx
=
blockIdx
.
x
;
int64_t
pos
=
positions
[
token_idx
];
int64_t
pos
=
positions
[
token_idx
];
...
@@ -88,7 +96,7 @@ __global__ void rotary_embedding_kernel(
...
@@ -88,7 +96,7 @@ __global__ void rotary_embedding_kernel(
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
token_idx
,
query_stride
,
key_stride
,
head_stride
);
}
}
template
<
typename
scalar_t
,
bool
IS_NEOX
>
template
<
typename
scalar_t
,
bool
IS_NEOX
>
...
@@ -98,15 +106,16 @@ __global__ void batched_rotary_embedding_kernel(
...
@@ -98,15 +106,16 @@ __global__ void batched_rotary_embedding_kernel(
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
scalar_t
*
__restrict__
key
,
// nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
// head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
// 2]
// 2]
const
int64_t
*
__restrict__
cos_sin_cache_offsets
,
// [batch_size, seq_len]
const
int64_t
*
__restrict__
cos_sin_cache_offsets
,
// [batch_size, seq_len]
// or [num_tokens]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
const
int64_t
head_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
// Each thread block is responsible for one token.
// Each thread block is responsible for one token.
const
int
token_idx
=
blockIdx
.
x
;
const
int
token_idx
=
blockIdx
.
x
;
int64_t
pos
=
positions
[
token_idx
];
int64_t
pos
=
positions
[
token_idx
];
...
@@ -116,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel(
...
@@ -116,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel(
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
token_idx
,
query_stride
,
key_stride
,
head_stride
);
}
}
}
// namespace vllm
}
// namespace vllm
...
@@ -127,10 +136,12 @@ void rotary_embedding(
...
@@ -127,10 +136,12 @@ void rotary_embedding(
// [num_tokens, num_heads * head_size] or
// [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
// [num_tokens, num_heads, head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or
std
::
optional
<
torch
::
Tensor
>
key
,
// [num_tokens, num_kv_heads * head_size] or
// null or
// [batch_size, seq_len, num_heads, head_size] or
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_heads, head_size]
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t
head_size
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
)
{
bool
is_neox
)
{
...
@@ -138,40 +149,46 @@ void rotary_embedding(
...
@@ -138,40 +149,46 @@ void rotary_embedding(
int64_t
num_tokens
=
positions
.
numel
();
int64_t
num_tokens
=
positions
.
numel
();
int
positions_ndim
=
positions
.
dim
();
int
positions_ndim
=
positions
.
dim
();
// Make sure num_tokens dim is consistent across positions, query, and key
.
// Make sure num_tokens dim is consistent across positions, query, and key
TORCH_CHECK
(
TORCH_CHECK
(
positions_ndim
==
1
||
positions_ndim
==
2
,
positions_ndim
==
1
||
positions_ndim
==
2
,
"positions must have shape [num_tokens] or [batch_size, seq_len]"
);
"positions must have shape [num_tokens] or [batch_size, seq_len]"
);
if
(
positions_ndim
==
1
)
{
if
(
positions_ndim
==
1
)
{
TORCH_CHECK
(
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
query
.
siz
e
(
0
)
==
positions
.
size
(
0
)
&&
key
.
size
(
0
)
==
positions
.
size
(
0
),
(
!
key
.
has_valu
e
()
||
key
->
size
(
0
)
==
positions
.
size
(
0
)
)
,
"query, key and positions must have the same number of tokens"
);
"query, key and positions must have the same number of tokens"
);
}
}
if
(
positions_ndim
==
2
)
{
if
(
positions_ndim
==
2
)
{
TORCH_CHECK
(
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
key
.
size
(
0
)
==
positions
.
size
(
0
)
&&
(
!
key
.
has_value
()
||
key
->
size
(
0
)
==
positions
.
size
(
0
)
)
&&
query
.
size
(
1
)
==
positions
.
size
(
1
)
&&
query
.
size
(
1
)
==
positions
.
size
(
1
)
&&
key
.
size
(
1
)
==
positions
.
size
(
1
),
(
!
key
.
has_value
()
||
key
->
size
(
1
)
==
positions
.
size
(
1
)
)
,
"query, key and positions must have the same batch_size and seq_len"
);
"query, key and positions must have the same batch_size and seq_len"
);
}
}
// Make sure head_size is valid for query and key
// Make sure head_size is valid for query and key
// hidden_size = num_heads * head_size
// hidden_size = num_heads * head_size
int
query_hidden_size
=
query
.
numel
()
/
num_tokens
;
int
query_hidden_size
=
query
.
numel
()
/
num_tokens
;
int
key_hidden_size
=
key
.
numel
()
/
num_tokens
;
int
key_hidden_size
=
key
.
has_value
()
?
key
->
numel
()
/
num_tokens
:
0
;
TORCH_CHECK
(
query_hidden_size
%
head_size
==
0
);
TORCH_CHECK
(
query_hidden_size
%
head_size
==
0
);
TORCH_CHECK
(
key_hidden_size
%
head_size
==
0
);
TORCH_CHECK
(
key_hidden_size
%
head_size
==
0
);
// Make sure query and key have consistent number of heads
// Make sure query and key have consistent number of heads
int
num_heads
=
query_hidden_size
/
head_size
;
int
num_heads
=
query_hidden_size
/
head_size
;
int
num_kv_heads
=
key_hidden_size
/
head_size
;
int
num_kv_heads
=
key
.
has_value
()
?
key_hidden_size
/
head_size
:
num_heads
;
TORCH_CHECK
(
num_heads
%
num_kv_heads
==
0
);
TORCH_CHECK
(
num_heads
%
num_kv_heads
==
0
);
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
seq_dim_idx
=
positions_ndim
-
1
;
int
seq_dim_idx
=
positions_ndim
-
1
;
int64_t
query_stride
=
query
.
stride
(
seq_dim_idx
);
int64_t
query_stride
=
query
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
has_value
()
?
key
->
stride
(
seq_dim_idx
)
:
0
;
// Determine head stride: for [*, heads, head_size] use stride of last dim;
// for flat [*, heads*head_size], heads blocks are contiguous of size
// head_size
int
query_ndim
=
query
.
dim
();
int64_t
head_stride
=
(
query_ndim
==
positions_ndim
+
2
)
?
query
.
stride
(
-
2
)
:
head_size
;
dim3
grid
(
num_tokens
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
...
@@ -181,15 +198,16 @@ void rotary_embedding(
...
@@ -181,15 +198,16 @@ void rotary_embedding(
if
(
is_neox
)
{
if
(
is_neox
)
{
vllm
::
rotary_embedding_kernel
<
scalar_t
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
vllm
::
rotary_embedding_kernel
<
scalar_t
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
head_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
else
{
}
else
{
vllm
::
rotary_embedding_kernel
<
scalar_t
,
false
>
vllm
::
rotary_embedding_kernel
<
scalar_t
,
false
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
head_size
);
key_stride
,
head_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
}
});
});
}
}
...
@@ -204,10 +222,12 @@ void batched_rotary_embedding(
...
@@ -204,10 +222,12 @@ void batched_rotary_embedding(
// [num_tokens, num_heads * head_size] or
// [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
// [num_tokens, num_heads, head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or
std
::
optional
<
torch
::
Tensor
>
// [num_tokens, num_kv_heads * head_size] or
key
,
// null or
// [batch_size, seq_len, num_heads, head_size] or
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_heads, head_size]
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t
head_size
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
,
int64_t
rot_dim
,
bool
is_neox
,
int64_t
rot_dim
,
...
@@ -221,38 +241,44 @@ void batched_rotary_embedding(
...
@@ -221,38 +241,44 @@ void batched_rotary_embedding(
"cos_sin_cache_offsets"
);
"cos_sin_cache_offsets"
);
int
positions_ndim
=
positions
.
dim
();
int
positions_ndim
=
positions
.
dim
();
// Make sure num_tokens dim is consistent across positions, query, and key
.
// Make sure num_tokens dim is consistent across positions, query, and key
TORCH_CHECK
(
TORCH_CHECK
(
positions_ndim
==
1
||
positions_ndim
==
2
,
positions_ndim
==
1
||
positions_ndim
==
2
,
"positions must have shape [num_tokens] or [batch_size, seq_len]"
);
"positions must have shape [num_tokens] or [batch_size, seq_len]"
);
if
(
positions_ndim
==
1
)
{
if
(
positions_ndim
==
1
)
{
TORCH_CHECK
(
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
query
.
siz
e
(
0
)
==
positions
.
size
(
0
)
&&
key
.
size
(
0
)
==
positions
.
size
(
0
),
(
!
key
.
has_valu
e
()
||
key
->
size
(
0
)
==
positions
.
size
(
0
)
)
,
"query, key and positions must have the same number of tokens"
);
"query, key and positions must have the same number of tokens"
);
}
}
if
(
positions_ndim
==
2
)
{
if
(
positions_ndim
==
2
)
{
TORCH_CHECK
(
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
key
.
size
(
0
)
==
positions
.
size
(
0
)
&&
(
!
key
.
has_value
()
||
key
->
size
(
0
)
==
positions
.
size
(
0
)
)
&&
query
.
size
(
1
)
==
positions
.
size
(
1
)
&&
query
.
size
(
1
)
==
positions
.
size
(
1
)
&&
key
.
size
(
1
)
==
positions
.
size
(
1
),
(
!
key
.
has_value
()
||
key
->
size
(
1
)
==
positions
.
size
(
1
)
)
,
"query, key and positions must have the same batch_size and seq_len"
);
"query, key and positions must have the same batch_size and seq_len"
);
}
}
// Make sure head_size is valid for query and key
// Make sure head_size is valid for query and key
int
query_hidden_size
=
query
.
numel
()
/
num_tokens
;
int
query_hidden_size
=
query
.
numel
()
/
num_tokens
;
int
key_hidden_size
=
key
.
numel
()
/
num_tokens
;
int
key_hidden_size
=
key
.
has_value
()
?
key
->
numel
()
/
num_tokens
:
0
;
TORCH_CHECK
(
query_hidden_size
%
head_size
==
0
);
TORCH_CHECK
(
query_hidden_size
%
head_size
==
0
);
TORCH_CHECK
(
key_hidden_size
%
head_size
==
0
);
TORCH_CHECK
(
key_hidden_size
%
head_size
==
0
);
// Make sure query and key have concistent number of heads
// Make sure query and key have concistent number of heads
int
num_heads
=
query_hidden_size
/
head_size
;
int
num_heads
=
query_hidden_size
/
head_size
;
int
num_kv_heads
=
key_hidden_size
/
head_size
;
int
num_kv_heads
=
key
.
has_value
()
?
key_hidden_size
/
head_size
:
num_heads
;
TORCH_CHECK
(
num_heads
%
num_kv_heads
==
0
);
TORCH_CHECK
(
num_heads
%
num_kv_heads
==
0
);
int
seq_dim_idx
=
positions_ndim
-
1
;
int
seq_dim_idx
=
positions_ndim
-
1
;
int64_t
query_stride
=
query
.
stride
(
seq_dim_idx
);
int64_t
query_stride
=
query
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
has_value
()
?
key
->
stride
(
seq_dim_idx
)
:
0
;
// Determine head stride: for [*, heads, head_size] use stride of last dim;
// for flat [*, heads*head_size], heads blocks are contiguous of size
// head_size
int
query_ndim
=
query
.
dim
();
int64_t
head_stride
=
(
query_ndim
==
positions_ndim
+
2
)
?
query
.
stride
(
-
2
)
:
head_size
;
dim3
grid
(
num_tokens
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
...
@@ -263,16 +289,18 @@ void batched_rotary_embedding(
...
@@ -263,16 +289,18 @@ void batched_rotary_embedding(
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
true
>
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
true
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
key_stride
,
head_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
else
{
}
else
{
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
false
>
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
false
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
key_stride
,
head_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
}
});
});
}
}
csrc/quantization/activation_kernels.cu
0 → 100644
View file @
4c676e3d
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <cmath>
#include "core/math.hpp"
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"
namespace
vllm
{
// SiLU activation: silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).
// The negation happens in T before widening; the division is evaluated in
// float and the quotient is narrowed back to T.
template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  // x * sigmoid(x)
  const float numerator = static_cast<float>(x);
  const float negated = static_cast<float>(-x);
  return static_cast<T>(numerator / (1.0f + expf(negated)));
}
// Activation and gating kernel template.
// Activation and gating kernel template with fused FP8 conversion.
//
// blockIdx.x selects the token; the gridDim.y blocks working on one token
// split its d output elements into contiguous chunks. Each output element is
//   scaled_fp8_conversion<true, fp8_type>(ACT_FN(x[i]) * y[i], 1 / *scale)
// where x is the first d elements of the token's input row and y the next d.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
          typename fp8_type>
__global__ void act_and_mul_quant_kernel(
    fp8_type* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const float* scale, const int d) {
  const int32_t blocks_per_token = gridDim.y;

  // Number of scalar_t elements contained in one 128-bit (16-byte) load.
  const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t);

  // We don't expect the hidden dimension to exceed 32 bits so int32 should
  // be safe here.
  const int32_t chunk_target = div_ceil(d, blocks_per_token);
  const int32_t chunk_elems =
      round_to_next_multiple_of(chunk_target, elems_per_128bit_load);
  const int32_t block_start = blockIdx.y * chunk_elems;
  int32_t block_end = block_start + chunk_elems;
  if (block_end > d) {
    // The last chunk is clipped to the row length.
    block_end = d;
  }

  // token_idx is 64 bit to prevent 32 bit overflow when the number of tokens
  // is very large
  const int64_t token_idx = blockIdx.x;
  const scalar_t* token_in = input + token_idx * 2 * d;
  const scalar_t* __restrict__ x_ptr = token_in;
  const scalar_t* __restrict__ y_ptr = token_in + d;
  fp8_type* __restrict__ out_ptr = out + token_idx * d;

  // 128-bit vectorized code
  const int32_t vec_loop_end =
      round_to_previous_multiple_of(elems_per_128bit_load, block_end);
  const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load;
  const int32_t vec_start_idx = block_start / elems_per_128bit_load;

  // Views used by the vector loop: 16-byte input vectors, 8-byte output
  // vectors (int2).
  const int4* __restrict__ x_vec_ptr = reinterpret_cast<const int4*>(x_ptr);
  const int4* __restrict__ y_vec_ptr = reinterpret_cast<const int4*>(y_ptr);
  int2* __restrict__ out_vec_ptr = reinterpret_cast<int2*>(out_ptr);

  using scalar_128bit_vec_t = std::array<scalar_t, elems_per_128bit_load>;
  using scalar_64bit_vec_t = std::array<fp8_type, elems_per_128bit_load>;

  float inverted_scale = 1 / *scale;
#pragma unroll
  for (int32_t v = vec_start_idx + threadIdx.x; v < vec_end_idx;
       v += blockDim.x) {
    const int4 x_raw = VLLM_LDG(&x_vec_ptr[v]);
    const int4 y_raw = VLLM_LDG(&y_vec_ptr[v]);

    const auto x_vec = reinterpret_cast<scalar_128bit_vec_t const&>(x_raw);
    const auto y_vec = reinterpret_cast<scalar_128bit_vec_t const&>(y_raw);

    scalar_64bit_vec_t out_vec;
#pragma unroll
    for (int i = 0; i < elems_per_128bit_load; i++) {
      out_vec[i] = scaled_fp8_conversion<true, fp8_type>(
          ACT_FN(x_vec[i]) * y_vec[i], inverted_scale);
    }

    out_vec_ptr[v] = reinterpret_cast<const int2&>(out_vec);
  }

  // Scalar cleanup code
  if (block_end <= vec_loop_end) {
    return;
  }
  for (int64_t idx = vec_loop_end + threadIdx.x; idx < block_end;
       idx += blockDim.x) {
    const scalar_t x = VLLM_LDG(&x_ptr[idx]);
    const scalar_t y = VLLM_LDG(&y_ptr[idx]);
    out_ptr[idx] =
        scaled_fp8_conversion<true, fp8_type>(ACT_FN(x) * y, inverted_scale);
  }
}
}
// namespace vllm
// Launch activation, gating, and quantize kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                                \
  /* Expects tensors `out`, `input` and `scale` in the enclosing scope. */   \
  /* d is half of the last input dimension; one grid row per token. */       \
  int d = input.size(-1) / 2;                                                \
  int64_t num_tokens = input.numel() / input.size(-1);                       \
  /* grid.y: 4 blocks per token for <= 16 tokens, 2 for 17..32, else 1. */   \
  dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4);      \
  dim3 block(std::min(d, 512));                                              \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));          \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();              \
  VLLM_DISPATCH_FLOATING_TYPES(                                              \
      input.scalar_type(), "act_and_mul_kernel", [&] {                       \
        /* Dispatch label names this kernel so dtype errors are traceable */ \
        VLLM_DISPATCH_FP8_TYPES(                                             \
            out.scalar_type(), "act_and_mul_quant_kernel_fp8_type", [&] {    \
              vllm::act_and_mul_quant_kernel<scalar_t, KERNEL<scalar_t>,     \
                                             fp8_t>                          \
                  <<<grid, block, 0, stream>>>(out.data_ptr<fp8_t>(),        \
                                               input.data_ptr<scalar_t>(),   \
                                               scale.data_ptr<float>(), d);  \
            });                                                              \
      });
// Fused SiLU-and-multiply with FP8 output:
//   out[..., i] = fp8(silu(input[..., i]) * input[..., d + i], 1 / scale)
//
// out:   FP8 (e4m3fn or e4m3fnuz) tensor.
// input: float16 or bfloat16 tensor whose last dimension is even (2 * d).
// scale: tensor read as a single float by the kernel.
void silu_and_mul_quant(torch::Tensor& out,    // [..., d]
                        torch::Tensor& input,  // [..., 2 * d]
                        torch::Tensor& scale) {
  TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn ||
              out.dtype() == torch::kFloat8_e4m3fnuz);
  TORCH_CHECK(input.dtype() == torch::kFloat16 ||
              input.dtype() == torch::kBFloat16);
  TORCH_CHECK(input.size(-1) % 2 == 0);
  // Nothing to do for an empty input. Returning here also avoids the
  // division by input.size(-1) == 0 in the launch macro and a kernel launch
  // with a zero-sized grid or block dimension.
  if (input.numel() == 0) {
    return;
  }
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
View file @
4c676e3d
...
@@ -26,7 +26,13 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
...
@@ -26,7 +26,13 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
float
dst
=
std
::
nearbyint
(
x
);
float
dst
=
std
::
nearbyint
(
x
);
// saturate
// saturate
dst
=
std
::
clamp
(
dst
,
i8_min
,
i8_max
);
// See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
// Arch/gcc14. The following replaces std::clamp usage with similar logic
// dst = std::clamp(dst, i8_min, i8_max);
dst
=
(
dst
<
i8_min
)
?
i8_min
:
(
dst
>
i8_max
)
?
i8_max
:
dst
;
return
static_cast
<
int8_t
>
(
dst
);
return
static_cast
<
int8_t
>
(
dst
);
#else
#else
// CUDA path
// CUDA path
...
@@ -79,7 +85,13 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
...
@@ -79,7 +85,13 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
static_cast
<
int32_t
>
(
std
::
numeric_limits
<
int8_t
>::
max
());
static_cast
<
int32_t
>
(
std
::
numeric_limits
<
int8_t
>::
max
());
// saturate
// saturate
int32_t
dst
=
std
::
clamp
(
x
,
i8_min
,
i8_max
);
// See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
// Arch/gcc14. The following replaces std::clamp usage with similar logic
// int32_t dst = std::clamp(x, i8_min, i8_max);
int32_t
dst
=
(
x
<
i8_min
)
?
i8_min
:
(
x
>
i8_max
)
?
i8_max
:
x
;
return
static_cast
<
int8_t
>
(
dst
);
return
static_cast
<
int8_t
>
(
dst
);
#else
#else
// CUDA path
// CUDA path
...
...
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
0 → 100644
View file @
4c676e3d
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
// Entry point for the SM100 blockwise-scaled FP8 GEMM. Selects the CUTLASS
// output element type from out.dtype(): bfloat16 or float16; any other
// output dtype fails the TORCH_CHECK.
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
                                           torch::Tensor const& a,
                                           torch::Tensor const& b,
                                           torch::Tensor const& a_scales,
                                           torch::Tensor const& b_scales) {
  const bool is_bf16_out = (out.dtype() == torch::kBFloat16);
  if (is_bf16_out) {
    cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
        out, a, b, a_scales, b_scales);
    return;
  }

  // The only other supported output type is half precision.
  TORCH_CHECK(out.dtype() == torch::kFloat16);
  cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
      out, a, b, a_scales, b_scales);
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
0 → 100644
View file @
4c676e3d
#pragma once
#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh"
namespace
vllm
{
using
namespace
cute
;
// clang-format off
// Type bundle describing one blockwise-scaled FP8 (e4m3) GEMM configuration
// for SM100. All members are compile-time types/constants; the resulting
// kernel type is exposed as `GemmKernel`.
//
// OutType              - element type of the output D.
// ScaleGranularityM/N/K - granularities forwarded to
//                         Sm100BlockwiseScaleConfig.
// MmaTileShape / ClusterShape - tile and cluster shapes for the collectives.
// EpilogueScheduler / MainloopScheduler - schedule tags for the builders.
// swap_ab_             - when true, the mainloop is built with B as the first
//                        operand and A as the second, and C/D use the
//                        transposed layouts.
template <class OutType, int ScaleGranularityM, int ScaleGranularityN,
          int ScaleGranularityK, class MmaTileShape, class ClusterShape,
          class EpilogueScheduler, class MainloopScheduler,
          bool swap_ab_ = false>
struct cutlass_3x_gemm_fp8_blockwise {
  static constexpr bool swap_ab = swap_ab_;
  using ElementAB = cutlass::float_e4m3_t;

  // Operand A: row-major FP8; alignment is the element count of 128 bits.
  using ElementA = ElementAB;
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutA_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutA>::type;
  static constexpr int AlignmentA =
      128 / cutlass::sizeof_bits<ElementA>::value;

  // Operand B: column-major FP8.
  using ElementB = ElementAB;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutB_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutB>::type;
  static constexpr int AlignmentB =
      128 / cutlass::sizeof_bits<ElementB>::value;

  // Output D: row-major OutType.
  using ElementD = OutType;
  using LayoutD = cutlass::layout::RowMajor;
  using LayoutD_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutD>::type;
  static constexpr int AlignmentD =
      128 / cutlass::sizeof_bits<ElementD>::value;

  // No source tensor C (void element); layouts mirror D.
  using ElementC = void;  // TODO: support bias
  using LayoutC = LayoutD;
  using LayoutC_Transpose = LayoutD_Transpose;
  static constexpr int AlignmentC = AlignmentD;

  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementBlockScale = float;

  // The major modes of the two scale-factor tensors are exchanged when A and
  // B are swapped.
  using ScaleConfig = conditional_t<
      swap_ab,
      cutlass::detail::Sm100BlockwiseScaleConfig<
          ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
          cute::UMMA::Major::K, cute::UMMA::Major::MN>,
      cutlass::detail::Sm100BlockwiseScaleConfig<
          ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
          cute::UMMA::Major::MN, cute::UMMA::Major::K>>;

  // layout_SFA and layout_SFB cannot be swapped since they are deduced.
  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());

  using ArchTag = cutlass::arch::Sm100;
  using OperatorClass = cutlass::arch::OpClassTensorOp;

  // Epilogue: plain linear combination with round-to-nearest conversion.
  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
  using ElementScalar = float;
  using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<
      ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag, OperatorClass, MmaTileShape, ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
          ElementCompute, ElementC,
          conditional_t<swap_ab, LayoutC_Transpose, LayoutC>, AlignmentC,
          ElementD, conditional_t<swap_ab, LayoutD_Transpose, LayoutD>,
          AlignmentD, EpilogueScheduler, DefaultOperation>::CollectiveOp;

  using StageCountType = cutlass::gemm::collective::StageCountAuto;

  // Mainloop: the stage count is chosen automatically after carving out the
  // epilogue's shared storage. With swap_ab, B (transposed layout) is passed
  // in the first operand slot and A (transposed layout) in the second.
  using CollectiveMainloop = conditional_t<
      swap_ab,
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass, ElementB,
          cute::tuple<LayoutB_Transpose, LayoutSFA>, AlignmentB, ElementA,
          cute::tuple<LayoutA_Transpose, LayoutSFB>, AlignmentA,
          ElementAccumulator, MmaTileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          MainloopScheduler>::CollectiveOp,
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass, ElementA,
          cute::tuple<LayoutA, LayoutSFA>, AlignmentA, ElementB,
          cute::tuple<LayoutB, LayoutSFB>, AlignmentB, ElementAccumulator,
          MmaTileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          MainloopScheduler>::CollectiveOp>;

  // Universal GEMM kernel over a runtime (int, int, int, int) problem shape,
  // wrapped in enable_sm100_only.
  using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;

  struct GemmKernel : public KernelType {};
};
// Builds the CUTLASS argument structures for a blockwise-scaled FP8 GEMM
// described by `Gemm` and launches it through c3x::cutlass_gemm_caller.
//
// m = a.size(0), k = a.size(1), n = b.size(1). When Gemm::swap_ab is set the
// kernel is given the problem as (n, m, k) with B and its scales in the first
// operand slot and A and its scales in the second.
template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
                                   torch::Tensor const& b,
                                   torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
  static constexpr bool swap_ab = Gemm::swap_ab;
  using GemmKernel = typename Gemm::GemmKernel;
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideD = typename Gemm::GemmKernel::StrideD;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using LayoutSFA = typename Gemm::LayoutSFA;
  using LayoutSFB = typename Gemm::LayoutSFB;
  using ScaleConfig = typename Gemm::ScaleConfig;

  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0), n = b.size(1), k = a.size(1);

  // Extents as seen by the kernel: M and N trade places under swap_ab.
  int32_t gemm_m = swap_ab ? n : m;
  int32_t gemm_n = swap_ab ? m : n;

  // Packed strides for the operands as stored in memory, and for the output
  // in the kernel's (possibly swapped) orientation.
  StrideA a_stride =
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB b_stride =
      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  StrideC c_stride = cutlass::make_cute_packed_stride(
      StrideC{}, cute::make_shape(gemm_m, gemm_n, 1));

  // Scale-factor layouts are derived from the kernel-side problem shape.
  LayoutSFA layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(
      cute::make_shape(gemm_m, gemm_n, k, 1));
  LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(
      cute::make_shape(gemm_m, gemm_n, k, 1));

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
  auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());

  auto mainloop_args = [&]() {
    // layout_SFA and layout_SFB cannot be swapped since they are deduced.
    if (swap_ab) {
      return typename GemmKernel::MainloopArguments{
          b_ptr,        b_stride,   a_ptr,        a_stride,
          b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB};
    } else {
      return typename GemmKernel::MainloopArguments{
          a_ptr,        a_stride,   b_ptr,        b_stride,
          a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB};
    }
  }();

  auto prob_shape = cute::make_shape(gemm_m, gemm_n, k, 1);

  // The output buffer is passed in both the C and D pointer slots.
  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      {}, c_ptr, c_stride, c_ptr, c_stride};
  c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
                                       epilogue_args);
}
// Chooses a blockwise FP8 GEMM configuration for SM100 from the problem size
// and launches it.
//
// Heuristics (m = a.size(0), n = b.size(1)):
//  * swap A/B when m < 16 or m is not a multiple of 4; that path always uses
//    the NoSmem epilogue with a 128x16x128 tile.
//  * otherwise pick the smallest M tile (64, 128, else 256) whose tile count
//    ceil(n / 128) * ceil(m / tile_m) does not exceed the SM count.
//  * use the TMA epilogue only when m * n is a multiple of 4.
template <typename OutType>
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
                                               torch::Tensor const& a,
                                               torch::Tensor const& b,
                                               torch::Tensor const& a_scales,
                                               torch::Tensor const& b_scales) {
  int32_t m = a.size(0), n = b.size(1);
  // Initialised so that a failed attribute query yields a defined value
  // (0 SMs selects the 256-row tile below) instead of an indeterminate read.
  int32_t sms = 0;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());

  constexpr int TILE_K = 128;
  // TODO: better heuristics
  bool swap_ab = (m < 16) || (m % 4 != 0);
  // Evaluate m * n in 64 bits: the 32-bit product overflows for large
  // problems, which is undefined behaviour for signed integers.
  bool use_tma_epilogue = (static_cast<int64_t>(m) * n) % 4 == 0;
  if (!swap_ab) {
    constexpr int TILE_N = 128;
    // Tile counts are compared in 64 bits for the same reason.
    const int64_t n_tiles = cuda_utils::ceil_div(n, TILE_N);
    int tile_m = 256;
    if (n_tiles * cuda_utils::ceil_div(m, 64) <= sms) {
      tile_m = 64;
    } else if (n_tiles * cuda_utils::ceil_div(m, 128) <= sms) {
      tile_m = 128;
    }
    if (tile_m == 64) {
      if (use_tma_epilogue) {
        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
            OutType, 1, TILE_N, TILE_K,
            Shape<_64, Int<TILE_N>, Int<TILE_K>>, Shape<_1, _1, _1>,
            cutlass::epilogue::TmaWarpSpecialized1Sm,
            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
            out, a, b, a_scales, b_scales);
      } else {
        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
            OutType, 1, TILE_N, TILE_K,
            Shape<_64, Int<TILE_N>, Int<TILE_K>>, Shape<_1, _1, _1>,
            cutlass::epilogue::NoSmemWarpSpecialized1Sm,
            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
            out, a, b, a_scales, b_scales);
      }
    } else if (tile_m == 128) {
      if (use_tma_epilogue) {
        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
            OutType, 1, TILE_N, TILE_K,
            Shape<_128, Int<TILE_N>, Int<TILE_K>>, Shape<_1, _1, _1>,
            cutlass::epilogue::TmaWarpSpecialized1Sm,
            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
            out, a, b, a_scales, b_scales);
      } else {
        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
            OutType, 1, TILE_N, TILE_K,
            Shape<_128, Int<TILE_N>, Int<TILE_K>>, Shape<_1, _1, _1>,
            cutlass::epilogue::NoSmemWarpSpecialized1Sm,
            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
            out, a, b, a_scales, b_scales);
      }
    } else {  // tile_m == 256
      if (use_tma_epilogue) {
        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
            OutType, 1, TILE_N, TILE_K,
            Shape<_256, Int<TILE_N>, Int<TILE_K>>, Shape<_2, _1, _1>,
            cutlass::epilogue::TmaWarpSpecialized2Sm,
            cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
            out, a, b, a_scales, b_scales);
      } else {
        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
            OutType, 1, TILE_N, TILE_K,
            Shape<_256, Int<TILE_N>, Int<TILE_K>>, Shape<_2, _1, _1>,
            cutlass::epilogue::NoSmemWarpSpecialized2Sm,
            cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
            out, a, b, a_scales, b_scales);
      }
    }
  } else {
    // TODO: Test more tile N configs
    constexpr int TILE_M = 128;
    constexpr int TILE_N = 16;
    // TMA epilogue isn't compatible with Swap A/B
    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
        OutType, TILE_M, 1, TILE_K,
        Shape<Int<TILE_M>, Int<TILE_N>, Int<TILE_K>>, Shape<_1, _1, _1>,
        cutlass::epilogue::NoSmemWarpSpecialized1Sm,
        cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>(
        out, a, b, a_scales, b_scales);
  }
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
0 → 100644
View file @
4c676e3d
#include <torch/all.h>
#include "cuda_utils.h"
#include "cutlass_extensions/common.hpp"
// Routes a scaled matmul to the kernel family that matches its scale layout.
//
//  * Per-tensor / per-token / per-channel scales (a_scales has 1 or a.size(0)
//    elements and b_scales has 1 or b.size(1) elements): calls fp8_func for
//    float8_e4m3fn inputs, otherwise int8_func for int8 inputs. Passing
//    nullptr as int8_func marks int8 as unsupported.
//  * Anything else is treated as blockwise scaling: both scale tensors must
//    be 2-D with group shapes [1, 128] (a) and [128, 128] (b), inputs must be
//    float8_e4m3fn, bias is rejected, and blockwise_func is called.
//
// Every violated precondition is reported through TORCH_CHECK.
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales,
                        std::optional<torch::Tensor> const& bias,
                        Fp8Func fp8_func, Int8Func int8_func,
                        BlockwiseFunc blockwise_func) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
      (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
    // Standard per-tensor/per-token/per-channel scaling
    TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
    if (a.dtype() == torch::kFloat8_e4m3fn) {
      fp8_func(c, a, b, a_scales, b_scales, bias);
    } else {
      TORCH_CHECK(a.dtype() == torch::kInt8);
      if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
        int8_func(c, a, b, a_scales, b_scales, bias);
      } else {
        TORCH_CHECK(false, "Int8 not supported for this architecture");
      }
    }
  } else {
    TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
    TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
    int32_t version_num = get_sm_version_num();
    if (version_num >= 100) {
      // Same dtype requirement the pre-SM100 branch enforces below: the
      // blockwise path is FP8-only.
      TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn &&
                      b.dtype() == torch::kFloat8_e4m3fn,
                  "cutlass_scaled_mm only supports datatype float8_e4m3fn.");
      TORCH_CHECK(
          a.size(0) == a_scales.size(0) &&
              cuda_utils::ceil_div(a.size(1), int64_t(128)) ==
                  a_scales.size(1),
          "a_scale_group_shape must be [1, 128].");
      TORCH_CHECK(
          cuda_utils::ceil_div(b.size(0), int64_t(128)) ==
                  b_scales.size(0) &&
              cuda_utils::ceil_div(b.size(1), int64_t(128)) ==
                  b_scales.size(1),
          "b_scale_group_shape must be [128, 128].");
    } else {
      // TODO: Remove this after using cutlass sm90 blockwise scaling gemm
      // kernel, or introducing ceil_div to the load_init() of mainloop.
      using GroupShape = std::array<int64_t, 2>;
      // Group shape = tensor extent divided (rounding up) by scale extent,
      // per dimension.
      auto make_group_shape = [](torch::Tensor const& x,
                                 torch::Tensor const& s) -> GroupShape {
        TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
        return {cuda_utils::ceil_div(x.size(0), s.size(0)),
                cuda_utils::ceil_div(x.size(1), s.size(1))};
      };

      GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
      GroupShape b_scale_group_shape = make_group_shape(b, b_scales);

      // 1x128 per-token group scales for activations
      // 128x128 blockwise scales for weights
      TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
                   b_scale_group_shape == GroupShape{128, 128} &&
                   a.dtype() == torch::kFloat8_e4m3fn &&
                   b.dtype() == torch::kFloat8_e4m3fn),
                  "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
                  "a_scale_group_shape must be [1, 128]. Got: [",
                  a_scale_group_shape[0], ", ", a_scale_group_shape[1],
                  "]\n"
                  "b_scale_group_shape must be [128, 128]. Got: [",
                  b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
    }

    TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
    blockwise_func(c, a, b, a_scales, b_scales);
  }
}
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
View file @
4c676e3d
...
@@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
...
@@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_blockwise_sm100_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
}
// namespace vllm
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
View file @
4c676e3d
...
@@ -15,6 +15,7 @@ using c3x::cutlass_gemm_caller;
...
@@ -15,6 +15,7 @@ using c3x::cutlass_gemm_caller;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm100_fp8_config_default
{
struct
sm100_fp8_config_default
{
// M in (128, inf)
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
;
using
KernelSchedule
=
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
...
@@ -25,6 +26,34 @@ struct sm100_fp8_config_default {
...
@@ -25,6 +26,34 @@ struct sm100_fp8_config_default {
KernelSchedule
,
EpilogueSchedule
>
;
KernelSchedule
,
EpilogueSchedule
>
;
};
};
// SM100 FP8 GEMM configuration used for M in (64, 128]:
// 128x128x64 tiles on a 2x2x1 cluster, with automatically chosen kernel and
// epilogue schedules. Only float_e4m3_t inputs are accepted.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M128 {
  // M in (64, 128]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());

  using TileShape = Shape<_128, _128, _64>;
  using ClusterShape = Shape<_2, _2, _1>;

  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;

  using Cutlass3xGemm =
      cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
                            KernelSchedule, EpilogueSchedule>;
};
// SM100 FP8 GEMM configuration used for M in [1, 64]:
// 64x64x256 tiles on a 1x8x1 cluster, with automatically chosen kernel and
// epilogue schedules. Only float_e4m3_t inputs are accepted.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M64 {
  // M in [1, 64]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());

  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _8, _1>;

  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;

  using Cutlass3xGemm =
      cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
                            KernelSchedule, EpilogueSchedule>;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
typename
...
EpilogueArgs
>
...
@@ -39,8 +68,28 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
...
@@ -39,8 +68,28 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
using
Cutlass3xGemmDefault
=
using
Cutlass3xGemmDefault
=
typename
sm100_fp8_config_default
<
InType
,
OutType
,
typename
sm100_fp8_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
Epilogue
>::
Cutlass3xGemm
;
return
cutlass_gemm_caller
<
Cutlass3xGemmDefault
>
(
using
Cutlass3xGemmM64
=
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
typename
sm100_fp8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
typename
sm100_fp8_config_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
64
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
64
)
{
// m in [1, 64]
return
cutlass_gemm_caller
<
Cutlass3xGemmM64
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// m in (64, 128]
return
cutlass_gemm_caller
<
Cutlass3xGemmM128
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
// m in (128, inf)
return
cutlass_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
}
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
...
...
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu
View file @
4c676e3d
...
@@ -84,7 +84,8 @@ void run_cutlass_moe_mm_sm90(
...
@@ -84,7 +84,8 @@ void run_cutlass_moe_mm_sm90(
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
)
{
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
,
bool
per_act_token
,
bool
per_out_ch
)
{
TORCH_CHECK
(
a_tensors
.
size
(
0
)
>
0
,
"No input A tensors provided."
);
TORCH_CHECK
(
a_tensors
.
size
(
0
)
>
0
,
"No input A tensors provided."
);
TORCH_CHECK
(
b_tensors
.
size
(
0
)
>
0
,
"No input B tensors provided."
);
TORCH_CHECK
(
b_tensors
.
size
(
0
)
>
0
,
"No input B tensors provided."
);
TORCH_CHECK
(
out_tensors
.
size
(
0
)
>
0
,
"No output tensors provided."
);
TORCH_CHECK
(
out_tensors
.
size
(
0
)
>
0
,
"No output tensors provided."
);
...
@@ -113,19 +114,23 @@ void run_cutlass_moe_mm_sm90(
...
@@ -113,19 +114,23 @@ void run_cutlass_moe_mm_sm90(
if
(
n
>=
8192
)
{
if
(
n
>=
8192
)
{
cutlass_group_gemm_caller
<
Cutlass3xGemmN8192
>
(
cutlass_group_gemm_caller
<
Cutlass3xGemmN8192
>
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
per_act_token
,
per_out_ch
);
}
else
if
(
k
>=
8192
)
{
}
else
if
(
k
>=
8192
)
{
cutlass_group_gemm_caller
<
Cutlass3xGemmK8192
>
(
cutlass_group_gemm_caller
<
Cutlass3xGemmK8192
>
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
per_act_token
,
per_out_ch
);
}
else
if
(
m
<=
16
)
{
}
else
if
(
m
<=
16
)
{
cutlass_group_gemm_caller
<
Cutlass3xGemmM16
>
(
cutlass_group_gemm_caller
<
Cutlass3xGemmM16
>
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
per_act_token
,
per_out_ch
);
}
else
{
}
else
{
cutlass_group_gemm_caller
<
Cutlass3xGemmDefault
>
(
cutlass_group_gemm_caller
<
Cutlass3xGemmDefault
>
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
per_act_token
,
per_out_ch
);
}
}
}
}
...
@@ -134,15 +139,18 @@ void dispatch_moe_mm_sm90(
...
@@ -134,15 +139,18 @@ void dispatch_moe_mm_sm90(
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
)
{
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
,
bool
per_act_token
,
bool
per_out_ch
)
{
if
(
out_tensors
.
dtype
()
==
torch
::
kBFloat16
)
{
if
(
out_tensors
.
dtype
()
==
torch
::
kBFloat16
)
{
run_cutlass_moe_mm_sm90
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
>
(
run_cutlass_moe_mm_sm90
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
>
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
per_act_token
,
per_out_ch
);
}
else
{
}
else
{
run_cutlass_moe_mm_sm90
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
>
(
run_cutlass_moe_mm_sm90
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
>
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
per_act_token
,
per_out_ch
);
}
}
}
}
...
@@ -153,8 +161,9 @@ void cutlass_moe_mm_sm90(
...
@@ -153,8 +161,9 @@ void cutlass_moe_mm_sm90(
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
)
{
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
,
bool
per_act_token
,
bool
per_out_ch
)
{
dispatch_moe_mm_sm90
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
dispatch_moe_mm_sm90
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
c_strides
,
per_act_token
,
per_out_ch
);
}
}
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh
View file @
4c676e3d
...
@@ -76,7 +76,8 @@ void cutlass_group_gemm_caller(
...
@@ -76,7 +76,8 @@ void cutlass_group_gemm_caller(
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
)
{
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
,
bool
per_act_token
,
bool
per_out_ch
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
using
ElementD
=
typename
Gemm
::
ElementD
;
...
@@ -84,9 +85,6 @@ void cutlass_group_gemm_caller(
...
@@ -84,9 +85,6 @@ void cutlass_group_gemm_caller(
int
k_size
=
a_tensors
.
size
(
1
);
int
k_size
=
a_tensors
.
size
(
1
);
int
n_size
=
out_tensors
.
size
(
1
);
int
n_size
=
out_tensors
.
size
(
1
);
bool
per_act_token
=
a_scales
.
numel
()
!=
1
;
bool
per_out_ch
=
b_scales
.
numel
()
!=
num_experts
;
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a_tensors
.
device
().
index
());
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a_tensors
.
device
().
index
());
auto
options_int
=
auto
options_int
=
...
...
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
View file @
4c676e3d
...
@@ -7,7 +7,7 @@
...
@@ -7,7 +7,7 @@
constexpr
uint64_t
THREADS_PER_EXPERT
=
512
;
constexpr
uint64_t
THREADS_PER_EXPERT
=
512
;
__global__
void
compute_problem_sizes
(
const
int
*
__restrict__
topk_ids
,
__global__
void
compute_problem_sizes
(
const
u
int
32_t
*
__restrict__
topk_ids
,
int32_t
*
problem_sizes1
,
int32_t
*
problem_sizes1
,
int32_t
*
problem_sizes2
,
int32_t
*
problem_sizes2
,
int32_t
*
atomic_buffer
,
int32_t
*
atomic_buffer
,
...
@@ -45,7 +45,24 @@ __global__ void compute_expert_offsets(
...
@@ -45,7 +45,24 @@ __global__ void compute_expert_offsets(
}
}
}
}
__global__
void
compute_arg_sorts
(
const
int
*
__restrict__
topk_ids
,
// Sequentially scans the per-expert row counts stored in problem_sizes1 (the
// first int32 of every group of three) and produces two running sums:
//   * expert_offsets[i + 1]     - sum of the row counts of experts 0..i
//   * blockscale_offsets[i + 1] - the same sum, with each expert's row count
//                                 first rounded up to a multiple of 128
// Element 0 of both outputs is set to 0. atomic_buffer[i] receives the
// un-rounded start offset of expert i (the value of expert_offsets[i]).
//
// The loop carries a dependency from one expert to the next, so the body is
// written for a single thread; the launch configuration is chosen by the
// caller.
__global__ void compute_expert_blockscale_offsets(
    const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
    int32_t* blockscale_offsets, int32_t* atomic_buffer,
    const int num_experts) {
  expert_offsets[0] = 0;
  blockscale_offsets[0] = 0;

  int32_t rows_so_far = 0;
  int32_t padded_rows_so_far = 0;
  for (int expert = 0; expert < num_experts; ++expert) {
    const int32_t rows = problem_sizes1[expert * 3];

    // Start offset of this expert, before its own rows are added.
    atomic_buffer[expert] = rows_so_far;

    rows_so_far += rows;
    expert_offsets[expert + 1] = rows_so_far;

    // Round this expert's row count up to the next multiple of 128.
    padded_rows_so_far += (rows + 127) / 128 * 128;
    blockscale_offsets[expert + 1] = padded_rows_so_far;
  }
}
__global__
void
compute_arg_sorts
(
const
uint32_t
*
__restrict__
topk_ids
,
const
int32_t
*
__restrict__
expert_offsets
,
const
int32_t
*
__restrict__
expert_offsets
,
int32_t
*
input_permutation
,
int32_t
*
input_permutation
,
int32_t
*
output_permutation
,
int32_t
*
output_permutation
,
...
@@ -77,7 +94,8 @@ void get_cutlass_moe_mm_data_caller(
...
@@ -77,7 +94,8 @@ void get_cutlass_moe_mm_data_caller(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
)
{
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
topk_ids
.
device
().
index
());
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
topk_ids
.
device
().
index
());
auto
options_int32
=
auto
options_int32
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
).
device
(
topk_ids
.
device
());
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
).
device
(
topk_ids
.
device
());
...
@@ -85,19 +103,61 @@ void get_cutlass_moe_mm_data_caller(
...
@@ -85,19 +103,61 @@ void get_cutlass_moe_mm_data_caller(
int
num_threads
=
min
(
THREADS_PER_EXPERT
,
topk_ids
.
numel
());
int
num_threads
=
min
(
THREADS_PER_EXPERT
,
topk_ids
.
numel
());
compute_problem_sizes
<<<
num_experts
,
num_threads
,
0
,
stream
>>>
(
compute_problem_sizes
<<<
num_experts
,
num_threads
,
0
,
stream
>>>
(
static_cast
<
const
int32_t
*>
(
topk_ids
.
data_ptr
()),
static_cast
<
const
u
int32_t
*>
(
topk_ids
.
data_ptr
()),
static_cast
<
int32_t
*>
(
problem_sizes1
.
data_ptr
()),
static_cast
<
int32_t
*>
(
problem_sizes1
.
data_ptr
()),
static_cast
<
int32_t
*>
(
problem_sizes2
.
data_ptr
()),
static_cast
<
int32_t
*>
(
problem_sizes2
.
data_ptr
()),
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
topk_ids
.
numel
(),
n
,
k
);
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
topk_ids
.
numel
(),
n
,
k
);
compute_expert_offsets
<<<
1
,
1
,
0
,
stream
>>>
(
if
(
blockscale_offsets
.
has_value
())
{
static_cast
<
const
int32_t
*>
(
problem_sizes1
.
data_ptr
()),
compute_expert_blockscale_offsets
<<<
1
,
1
,
0
,
stream
>>>
(
static_cast
<
int32_t
*>
(
expert_offsets
.
data_ptr
()),
static_cast
<
const
int32_t
*>
(
problem_sizes1
.
data_ptr
()),
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
num_experts
);
static_cast
<
int32_t
*>
(
expert_offsets
.
data_ptr
()),
static_cast
<
int32_t
*>
(
blockscale_offsets
.
value
().
data_ptr
()),
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
num_experts
);
}
else
{
compute_expert_offsets
<<<
1
,
1
,
0
,
stream
>>>
(
static_cast
<
const
int32_t
*>
(
problem_sizes1
.
data_ptr
()),
static_cast
<
int32_t
*>
(
expert_offsets
.
data_ptr
()),
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
num_experts
);
}
compute_arg_sorts
<<<
num_experts
,
num_threads
,
0
,
stream
>>>
(
compute_arg_sorts
<<<
num_experts
,
num_threads
,
0
,
stream
>>>
(
static_cast
<
const
int32_t
*>
(
topk_ids
.
data_ptr
()),
static_cast
<
const
u
int32_t
*>
(
topk_ids
.
data_ptr
()),
static_cast
<
const
int32_t
*>
(
expert_offsets
.
data_ptr
()),
static_cast
<
const
int32_t
*>
(
expert_offsets
.
data_ptr
()),
static_cast
<
int32_t
*>
(
input_permutation
.
data_ptr
()),
static_cast
<
int32_t
*>
(
input_permutation
.
data_ptr
()),
static_cast
<
int32_t
*>
(
output_permutation
.
data_ptr
()),
static_cast
<
int32_t
*>
(
output_permutation
.
data_ptr
()),
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
topk_ids
.
numel
(),
static_cast
<
int32_t
*>
(
atomic_buffer
.
data_ptr
()),
topk_ids
.
numel
(),
topk_ids
.
size
(
1
));
topk_ids
.
size
(
1
));
}
}
// Writes the grouped-GEMM metadata for one expert per thread; the expert index
// is the thread's threadIdx.x.
//
// For expert e:
//   expert_offsets[e]              = e * padded_m
//   problem_sizes1[3e .. 3e + 2]   = { expert_num_tokens[e], 2 * n, k }
//   problem_sizes2[3e .. 3e + 2]   = { expert_num_tokens[e], k,     n }
//
// There is no bounds check on the thread index, so the launch must not use
// more threads than there are entries in the buffers.
__global__ void compute_pplx_data(int32_t* expert_offsets,
                                  int32_t* problem_sizes1,
                                  int32_t* problem_sizes2,
                                  const int32_t* __restrict__ expert_num_tokens,
                                  const int padded_m, const int n,
                                  const int k) {
  const int expert = threadIdx.x;
  const int32_t num_tokens = expert_num_tokens[expert];

  expert_offsets[expert] = expert * padded_m;

  // Each expert owns three consecutive int32 slots in both size arrays.
  int32_t* const shape1 = problem_sizes1 + expert * 3;
  shape1[0] = num_tokens;
  shape1[1] = 2 * n;
  shape1[2] = k;

  int32_t* const shape2 = problem_sizes2 + expert * 3;
  shape2[0] = num_tokens;
  shape2[1] = k;
  shape2[2] = n;
}
// Host-side launcher for compute_pplx_data.
//
// Launches the kernel on the current CUDA stream of expert_offsets' device,
// using a single block with one thread per local expert. The kernel writes
// expert_offsets, problem_sizes1 and problem_sizes2 (all accessed as int32)
// from expert_num_tokens, padded_m, n and k.
//
// Throws (via TORCH_CHECK) if num_local_experts is negative or larger than the
// number of threads a single block can hold; returns without launching when
// num_local_experts is 0.
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
                                         torch::Tensor& problem_sizes1,
                                         torch::Tensor& problem_sizes2,
                                         const torch::Tensor& expert_num_tokens,
                                         const int64_t num_local_experts,
                                         const int64_t padded_m,
                                         const int64_t n, const int64_t k) {
  // The kernel maps threadIdx.x to the expert index inside one block, so the
  // expert count is limited by the per-block thread limit. Exceeding it makes
  // the launch fail without running the kernel, which would leave the output
  // tensors unwritten; reject it up front instead.
  constexpr int64_t kMaxThreadsPerBlock = 1024;
  TORCH_CHECK(
      num_local_experts >= 0 && num_local_experts <= kMaxThreadsPerBlock,
      "get_cutlass_pplx_moe_mm_data: num_local_experts (", num_local_experts,
      ") must be in [0, ", kMaxThreadsPerBlock,
      "] because compute_pplx_data runs one thread per expert in one block");

  // A zero-thread launch is an invalid configuration; there is nothing to do.
  if (num_local_experts == 0) {
    return;
  }

  auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());

  // The kernel takes 32-bit ints; make the narrowing explicit.
  compute_pplx_data<<<1, static_cast<unsigned int>(num_local_experts), 0,
                      stream>>>(
      static_cast<int32_t*>(expert_offsets.data_ptr()),
      static_cast<int32_t*>(problem_sizes1.data_ptr()),
      static_cast<int32_t*>(problem_sizes2.data_ptr()),
      static_cast<const int32_t*>(expert_num_tokens.data_ptr()),
      static_cast<int>(padded_m), static_cast<int>(n), static_cast<int>(k));
}
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
View file @
4c676e3d
#include
<cudaTypedefs.h>
#include
"c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"
#include "c3x/scaled_mm_kernels.hpp"
#include "cuda_utils.h"
/*
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm100 (Blackwell).
NVIDIA GPUs with sm100 (Blackwell).
...
@@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
...
@@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
dispatch_scaled_mm
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
,
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
vllm
::
cutlass_scaled_mm_sm100_fp8
,
nullptr
,
// int8 not supported on SM100
int
M
=
a
.
size
(
0
),
N
=
b
.
size
(
1
),
K
=
a
.
size
(
1
);
vllm
::
cutlass_scaled_mm_blockwise_sm100_fp8
);
TORCH_CHECK
(
(
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
))
&&
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
)),
"Currently, block scaled fp8 gemm is not implemented for Blackwell"
);
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
,
"Currently, only fp8 gemm is implemented for Blackwell"
);
vllm
::
cutlass_scaled_mm_sm100_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
}
#endif
#endif
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
View file @
4c676e3d
#include
<cudaTypedefs.h>
#include
"c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"
#include "c3x/scaled_mm_kernels.hpp"
#include "cuda_utils.h"
/*
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper).
NVIDIA GPUs with sm90a (Hopper).
...
@@ -15,49 +13,10 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
...
@@ -15,49 +13,10 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
dispatch_scaled_mm
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
,
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
vllm
::
cutlass_scaled_mm_sm90_fp8
,
vllm
::
cutlass_scaled_mm_sm90_int8
,
int
M
=
a
.
size
(
0
),
N
=
b
.
size
(
1
),
K
=
a
.
size
(
1
);
vllm
::
cutlass_scaled_mm_blockwise_sm90_fp8
);
if
((
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
))
&&
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
)))
{
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
)
{
vllm
::
cutlass_scaled_mm_sm90_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
vllm
::
cutlass_scaled_mm_sm90_int8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
}
else
{
using
GroupShape
=
std
::
array
<
int64_t
,
2
>
;
auto
make_group_shape
=
[](
torch
::
Tensor
const
&
x
,
torch
::
Tensor
const
&
s
)
->
GroupShape
{
TORCH_CHECK
(
s
.
dim
()
==
2
,
"cutlass_scaled_mm group scales must be 2D"
);
return
{
cuda_utils
::
ceil_div
(
x
.
size
(
0
),
s
.
size
(
0
)),
cuda_utils
::
ceil_div
(
x
.
size
(
1
),
s
.
size
(
1
))};
};
GroupShape
a_scale_group_shape
=
make_group_shape
(
a
,
a_scales
);
GroupShape
b_scale_group_shape
=
make_group_shape
(
b
,
b_scales
);
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK
((
a_scale_group_shape
==
GroupShape
{
1
,
128
}
&&
b_scale_group_shape
==
GroupShape
{
128
,
128
}
&&
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
&&
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
),
"cutlass_scaled_mm only supports datatype float8_e4m3fn.
\n
"
"a_scale_group_shape must be [1, 128]. Got: ["
,
a_scale_group_shape
[
0
],
", "
,
a_scale_group_shape
[
1
],
"]
\n
"
"b_scale_group_shape must be [128, 128]. Got: ["
,
b_scale_group_shape
[
0
],
", "
,
b_scale_group_shape
[
1
],
"]"
);
TORCH_CHECK
(
!
bias
,
"Bias not yet supported blockwise scaled_mm"
);
vllm
::
cutlass_scaled_mm_blockwise_sm90_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
void
cutlass_scaled_mm_azp_sm90
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_azp_sm90
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
4c676e3d
...
@@ -29,19 +29,15 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
...
@@ -29,19 +29,15 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
void
cutlass_moe_mm_sm90
(
void
cutlass_moe_mm_sm90
(
torch
::
Tensor
&
out_tensors
,
torch
::
Tensor
const
&
a_tensors
,
torch
::
Tensor
&
out_tensors
,
torch
::
Tensor
const
&
a_tensors
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
);
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
,
bool
per_act_token
,
bool
per_out_ch
);
void
get_cutlass_moe_mm_data_caller
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
);
#endif
#endif
...
@@ -53,6 +49,24 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
...
@@ -53,6 +49,24 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
#endif
#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \
defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100
void
get_cutlass_moe_mm_data_caller
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
);
void
get_cutlass_pplx_moe_mm_data_caller
(
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
const
torch
::
Tensor
&
expert_num_tokens
,
const
int64_t
num_local_experts
,
const
int64_t
padded_m
,
const
int64_t
n
,
const
int64_t
k
);
#endif
void
cutlass_scaled_mm_azp_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_azp_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
...
@@ -110,6 +124,8 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
...
@@ -110,6 +124,8 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
#if defined CUDA_VERSION
if
(
cuda_device_capability
>=
90
&&
cuda_device_capability
<
100
)
{
if
(
cuda_device_capability
>=
90
&&
cuda_device_capability
<
100
)
{
return
CUDA_VERSION
>=
12000
;
return
CUDA_VERSION
>=
12000
;
}
else
if
(
cuda_device_capability
>=
100
)
{
return
CUDA_VERSION
>=
12080
;
}
}
#endif
#endif
...
@@ -117,7 +133,7 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
...
@@ -117,7 +133,7 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
}
}
bool
cutlass_group_gemm_supported
(
int64_t
cuda_device_capability
)
{
bool
cutlass_group_gemm_supported
(
int64_t
cuda_device_capability
)
{
// CUTLASS groped FP8 kernels need at least CUDA 12.3
// CUTLASS grouped FP8 kernels need at least CUDA 12.3
// and SM90 (Hopper)
// and SM90 (Hopper)
#if defined CUDA_VERSION
#if defined CUDA_VERSION
...
@@ -200,12 +216,13 @@ void cutlass_moe_mm(
...
@@ -200,12 +216,13 @@ void cutlass_moe_mm(
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
)
{
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
,
bool
per_act_token
,
bool
per_out_ch
)
{
int32_t
version_num
=
get_sm_version_num
();
int32_t
version_num
=
get_sm_version_num
();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
cutlass_moe_mm_sm90
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
cutlass_moe_mm_sm90
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
);
c_strides
,
per_act_token
,
per_out_ch
);
return
;
return
;
#endif
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
TORCH_CHECK_NOT_IMPLEMENTED
(
...
@@ -218,14 +235,17 @@ void get_cutlass_moe_mm_data(
...
@@ -218,14 +235,17 @@ void get_cutlass_moe_mm_data(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
)
{
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
,
const
std
::
optional
<
torch
::
Tensor
>&
blockscale_offsets
)
{
// This function currently gets compiled only if we have a valid cutlass moe
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
// mm to run it for.
int32_t
version_num
=
get_sm_version_num
();
int32_t
version_num
=
get_sm_version_num
();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
get_cutlass_moe_mm_data_caller
(
topk_ids
,
expert_offsets
,
problem_sizes1
,
get_cutlass_moe_mm_data_caller
(
topk_ids
,
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
input_permutation
,
problem_sizes2
,
input_permutation
,
output_permutation
,
num_experts
,
n
,
k
);
output_permutation
,
num_experts
,
n
,
k
,
blockscale_offsets
);
return
;
return
;
#endif
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
TORCH_CHECK_NOT_IMPLEMENTED
(
...
@@ -235,6 +255,29 @@ void get_cutlass_moe_mm_data(
...
@@ -235,6 +255,29 @@ void get_cutlass_moe_mm_data(
version_num
,
". Required capability: 90"
);
version_num
,
". Required capability: 90"
);
}
}
// Entry point that forwards to get_cutlass_pplx_moe_mm_data_caller when the
// build enables ENABLE_CUTLASS_MOE_SM90. In builds without that flag the call
// is compiled out and the function always raises a "not implemented" error
// that reports the detected SM version.
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
                                  torch::Tensor& problem_sizes1,
                                  torch::Tensor& problem_sizes2,
                                  const torch::Tensor& expert_num_tokens,
                                  const int64_t num_local_experts,
                                  const int64_t padded_m, const int64_t n,
                                  const int64_t k) {
  // Queried unconditionally; only consumed by the fallback error below.
  const int32_t sm_version = get_sm_version_num();

#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
  get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
                                      problem_sizes2, expert_num_tokens,
                                      num_local_experts, padded_m, n, k);
  return;
#endif

  // Reached only when the kernel path above was not compiled in.
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
      "for CUDA device capability: ",
      sm_version, ". Required capability: 90");
}
void
cutlass_scaled_mm_azp
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_azp
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
...
...
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
0 → 100644
View file @
4c676e3d
#include <torch/all.h>
#include <cutlass/arch/arch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include <cassert>
using
namespace
cute
;
// Device-side setup for the block-scaled grouped GEMM. Each thread handles the
// expert whose index equals its threadIdx.x and records, for that expert, the
// start pointers into A, B, the output, both scale-factor buffers and alpha,
// plus the SFA/SFB scale-factor layouts produced by ScaleConfig.
//
// problem_sizes_as_shapes is read as three int32 values (m, n, k) per expert.
// expert_offsets[e] is the expert's first row in A and in the output;
// sf_offsets[e] is its first row in the A scale-factor buffer. K and N are the
// values every expert's k and n are asserted to equal.
template <typename ElementAB, typename ElementC, typename ElementSF,
          typename ElementAccumulator, typename LayoutSFA, typename LayoutSFB,
          typename ScaleConfig>
__global__ void __get_group_gemm_starts(
    ElementAB** a_offsets, ElementAB** b_offsets, ElementC** out_offsets,
    ElementSF** a_scales_offsets, ElementSF** b_scales_offsets,
    ElementAccumulator** alpha_offsets, LayoutSFA* layout_sfa_base_as_int,
    LayoutSFB* layout_sfb_base_as_int, ElementAB* a_base_as_int,
    ElementAB* b_base_as_int, ElementC* out_base_as_int,
    ElementSF* a_scales_base_as_int, ElementSF* b_scales_base_as_int,
    ElementAccumulator* alphas_base_as_int, const int32_t* expert_offsets,
    const int32_t* sf_offsets, const int32_t* problem_sizes_as_shapes,
    const int K, const int N) {
  int64_t expert = threadIdx.x;
  // NOTE(review): threadIdx.x is always below blockDim.x, so this guard looks
  // like it can never fire; it does not bound the thread by the number of
  // experts - confirm intent.
  if (expert >= gridDim.x * blockDim.x) {
    return;
  }

  // Number of elements that share one scale factor.
  int64_t scale_group_size = 16;

  // The inputs are int32; everything is widened to int64 up front so the
  // offset products below cannot overflow 32 bits.
  int64_t row_start = static_cast<int64_t>(expert_offsets[expert]);
  int64_t sf_row_start = static_cast<int64_t>(sf_offsets[expert]);
  int64_t m = static_cast<int64_t>(problem_sizes_as_shapes[expert * 3]);
  int64_t n = static_cast<int64_t>(problem_sizes_as_shapes[expert * 3 + 1]);
  int64_t k = static_cast<int64_t>(problem_sizes_as_shapes[expert * 3 + 2]);
  assert((m >= 0 && n == N && k == K && k % 2 == 0) &&
         "unexpected problem sizes");

  int64_t packed_k = static_cast<int64_t>(k / 2);
  int64_t scale_k = static_cast<int64_t>(k / scale_group_size);

  // A viewed as bytes is [M, K / 2]; B viewed as bytes is [E, N, K / 2].
  a_offsets[expert] = a_base_as_int + row_start * packed_k;
  b_offsets[expert] = b_base_as_int + expert * n * packed_k;

  // C is [M, N].
  out_offsets[expert] = out_base_as_int + row_start * n;

  // A scales are [sum(sf_sizes), K / group_size].
  a_scales_offsets[expert] = a_scales_base_as_int + sf_row_start * scale_k;
  assert((reinterpret_cast<uintptr_t>(a_scales_offsets[expert]) % 128) == 0 &&
         "TMA requires 128-byte alignment");

  // B scales are [E, N, K / group_size].
  b_scales_offsets[expert] = b_scales_base_as_int + expert * n * scale_k;
  assert((reinterpret_cast<uintptr_t>(b_scales_offsets[expert]) % 128) == 0 &&
         "TMA requires 128-byte alignment");

  // alpha is [E].
  alpha_offsets[expert] = alphas_base_as_int + expert;

  // One scale-factor layout per expert, built from its (m, n, k, 1) shape.
  layout_sfa_base_as_int[expert] =
      ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(
          static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
  layout_sfb_base_as_int[expert] =
      ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(
          static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
}
// Expands to one `else if` arm of a dtype dispatch chain, so it must directly
// follow an `if` / `else if` block. When out_tensors.dtype() equals
// TENSOR_C_TYPE, the arm launches __get_group_gemm_starts with a single block
// of num_experts threads on `stream`, instantiated with ELEMENT_AB_TYPE for
// A/B, C_TYPE for the output, SF_TYPE for the scale factors and float for the
// accumulator/alpha.
//
// The expansion refers to these names from the surrounding scope and they must
// exist at the point of use: out_tensors, a_starts, b_starts, out_starts,
// a_scales_starts, b_scales_starts, alpha_starts, layout_sfa, layout_sfb,
// a_tensors, b_tensors, a_scales, b_scales, alphas, expert_offsets,
// sf_offsets, problem_sizes, num_experts, stream, K and N.
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \
TENSOR_C_TYPE, C_TYPE, LayoutSFA, \
LayoutSFB, ScaleConfig) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, \
LayoutSFA, LayoutSFB, ScaleConfig> \
<<<1, num_experts, 0, stream>>>( \
static_cast<ELEMENT_AB_TYPE**>(a_starts.data_ptr()), \
static_cast<ELEMENT_AB_TYPE**>(b_starts.data_ptr()), \
static_cast<C_TYPE**>(out_starts.data_ptr()), \
static_cast<SF_TYPE**>(a_scales_starts.data_ptr()), \
static_cast<SF_TYPE**>(b_scales_starts.data_ptr()), \
static_cast<float**>(alpha_starts.data_ptr()), \
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(a_tensors.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<SF_TYPE*>(a_scales.data_ptr()), \
static_cast<SF_TYPE*>(b_scales.data_ptr()), \
static_cast<float*>(alphas.data_ptr()), \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<int32_t*>(sf_offsets.data_ptr()), \
static_cast<int32_t*>(problem_sizes.data_ptr()), K, N); \
}
// Host wrapper that validates the output/B shapes and then launches
// __get_group_gemm_starts through __CALL_GET_STARTS_KERNEL_BLOCKSCALE, picking
// the output element type from out_tensors.dtype() (bfloat16 or float16; any
// other dtype raises).
//
// The *_starts, layout_sfa and layout_sfb tensors are the buffers the kernel
// fills in; the remaining tensors supply base addresses and per-expert
// metadata. M is accepted but not used in this function.
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void run_get_group_gemm_starts(
    const torch::Tensor& a_starts, const torch::Tensor& b_starts,
    const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts,
    const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts,
    const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
    /*these are used for their base addresses*/
    torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
    torch::Tensor const& out_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& alphas,
    torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets,
    torch::Tensor const& problem_sizes, int M, int N, int K) {
  // `stream` and `num_experts` are consumed by name inside the dispatch macro
  // below, so these locals must keep exactly these names.
  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
  int num_experts = static_cast<int>(expert_offsets.size(0));

  TORCH_CHECK(out_tensors.size(1) == N,
              "Output tensor shape doesn't match expected shape");
  TORCH_CHECK(b_tensors.size(2) == K / 2,
              "b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
              " dimension must match");

  // Empty leading branch so that every macro expansion can be an `else if`.
  if (false) {
  }
  // Macro arguments: (ELEMENT_AB_TYPE, SF_TYPE, TENSOR_C_TYPE, C_TYPE,
  //                   LayoutSFA, LayoutSFB, ScaleConfig)
  __CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t,
                                      cutlass::float_ue4m3_t, torch::kBFloat16,
                                      cutlass::bfloat16_t, LayoutSFA, LayoutSFB,
                                      ScaleConfig)
  __CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t,
                                      cutlass::float_ue4m3_t, torch::kFloat16,
                                      half, LayoutSFA, LayoutSFB, ScaleConfig)
  else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
}
template
<
typename
OutType
>
void
run_fp4_blockwise_scaled_group_mm
(
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
a_blockscale
,
const
torch
::
Tensor
&
b_blockscales
,
const
torch
::
Tensor
&
alphas
,
const
torch
::
Tensor
&
problem_sizes
,
const
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
sf_offsets
,
int
M
,
int
N
,
int
K
)
{
using
ProblemShape
=
cutlass
::
gemm
::
GroupProblemShape
<
Shape
<
int32_t
,
int32_t
,
int32_t
>>
;
using
ElementType
=
cutlass
::
float_e2m1_t
;
using
ElementSFType
=
cutlass
::
float_ue4m3_t
;
using
ElementA
=
cutlass
::
nv_float4_t
<
cutlass
::
float_e2m1_t
>
;
using
ElementB
=
cutlass
::
nv_float4_t
<
cutlass
::
float_e2m1_t
>
;
using
ElementC
=
OutType
;
using
ElementD
=
ElementC
;
using
ElementAccumulator
=
float
;
// Layout definitions
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
using
LayoutC
=
cutlass
::
layout
::
RowMajor
;
using
LayoutD
=
LayoutC
;
// Alignment constraints
static
constexpr
int
AlignmentA
=
32
;
static
constexpr
int
AlignmentB
=
32
;
static
constexpr
int
AlignmentC
=
128
/
cutlass
::
sizeof_bits
<
ElementC
>::
value
;
static
constexpr
int
AlignmentD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
// Architecture definitions
using
ArchTag
=
cutlass
::
arch
::
Sm100
;
using
EpilogueOperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
// Epilogue Operator class tag
using
MainloopOperatorClass
=
cutlass
::
arch
::
OpClassBlockScaledTensorOp
;
// Mainloop Operator class tag
using
StageCountType
=
cutlass
::
gemm
::
collective
::
StageCountAuto
;
// Stage count maximized based
// on the tile size
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
struct
MMA1SMConfig
{
using
MmaTileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
KernelSchedule
=
cutlass
::
gemm
::
KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100
;
// Kernel to launch
using
EpilogueSchedule
=
cutlass
::
epilogue
::
PtrArrayTmaWarpSpecialized1Sm
;
// Epilogue to launch
};
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
EpilogueOperatorClass
,
typename
MMA1SMConfig
::
MmaTileShape
,
ClusterShape
,
Shape
<
_128
,
_64
>
,
ElementAccumulator
,
ElementAccumulator
,
ElementC
,
LayoutC
*
,
AlignmentC
,
ElementD
,
LayoutC
*
,
AlignmentD
,
typename
MMA1SMConfig
::
EpilogueSchedule
>::
CollectiveOp
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
MainloopOperatorClass
,
ElementA
,
LayoutA
*
,
AlignmentA
,
ElementB
,
LayoutB
*
,
AlignmentB
,
ElementAccumulator
,
typename
MMA1SMConfig
::
MmaTileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
typename
MMA1SMConfig
::
KernelSchedule
>::
CollectiveOp
;
using
GemmKernel
=
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
ProblemShape
,
CollectiveMainloop
,
CollectiveEpilogue
>
;
using
Gemm1SM
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
using
Gemm
=
Gemm1SM
;
using
StrideA
=
typename
Gemm
::
GemmKernel
::
InternalStrideA
;
using
StrideB
=
typename
Gemm
::
GemmKernel
::
InternalStrideB
;
using
StrideC
=
typename
Gemm
::
GemmKernel
::
InternalStrideC
;
using
StrideD
=
typename
Gemm
::
GemmKernel
::
InternalStrideD
;
using
LayoutSFA
=
typename
Gemm
::
GemmKernel
::
CollectiveMainloop
::
InternalLayoutSFA
;
using
LayoutSFB
=
typename
Gemm
::
GemmKernel
::
CollectiveMainloop
::
InternalLayoutSFB
;
using
ScaleConfig
=
typename
Gemm
::
GemmKernel
::
CollectiveMainloop
::
Sm1xxBlkScaledConfig
;
using
UnderlyingProblemShape
=
ProblemShape
::
UnderlyingProblemShape
;
int
num_experts
=
static_cast
<
int
>
(
expert_offsets
.
size
(
0
));
auto
options_int
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
a
.
device
());
torch
::
Tensor
a_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
b_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
out_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
a_scales_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
b_scales_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
alpha_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
layout_sfa
=
torch
::
empty
({
num_experts
,
5
},
options_int
);
torch
::
Tensor
layout_sfb
=
torch
::
empty
({
num_experts
,
5
},
options_int
);
torch
::
Tensor
c_strides1
=
torch
::
full
({
num_experts
},
output
.
stride
(
0
),
options_int
);
torch
::
Tensor
a_strides1
=
torch
::
full
({
num_experts
},
a
.
stride
(
0
)
*
2
,
options_int
);
torch
::
Tensor
b_strides1
=
torch
::
full
({
num_experts
},
b
.
stride
(
1
)
*
2
,
options_int
);
run_get_group_gemm_starts
<
LayoutSFA
,
LayoutSFB
,
ScaleConfig
>
(
a_ptrs
,
b_ptrs
,
out_ptrs
,
a_scales_ptrs
,
b_scales_ptrs
,
alpha_ptrs
,
layout_sfa
,
layout_sfb
,
a
,
b
,
output
,
a_blockscale
,
b_blockscales
,
alphas
,
expert_offsets
,
sf_offsets
,
problem_sizes
,
M
,
N
,
K
);
// Create an instance of the GEMM
Gemm
gemm_op
;
// Initialize problem_sizes_as_shapes correctly
UnderlyingProblemShape
*
problem_sizes_as_shapes
=
static_cast
<
UnderlyingProblemShape
*>
(
problem_sizes
.
data_ptr
());
// Set the Scheduler info
cutlass
::
KernelHardwareInfo
hw_info
;
using
RasterOrderOptions
=
typename
cutlass
::
gemm
::
kernel
::
detail
::
PersistentTileSchedulerSm100GroupParams
<
typename
ProblemShape
::
UnderlyingProblemShape
>::
RasterOrderOptions
;
typename
Gemm
::
GemmKernel
::
TileSchedulerArguments
scheduler
;
scheduler
.
raster_order
=
RasterOrderOptions
::
AlongM
;
hw_info
.
device_id
=
a
.
get_device
();
static
std
::
unordered_map
<
int
,
int
>
cached_sm_counts
;
if
(
cached_sm_counts
.
find
(
hw_info
.
device_id
)
==
cached_sm_counts
.
end
())
{
cached_sm_counts
[
hw_info
.
device_id
]
=
cutlass
::
KernelHardwareInfo
::
query_device_multiprocessor_count
(
hw_info
.
device_id
);
}
hw_info
.
sm_count
=
min
(
cached_sm_counts
[
hw_info
.
device_id
],
INT_MAX
);
// Mainloop Arguments
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
static_cast
<
const
ElementType
**>
(
a_ptrs
.
data_ptr
()),
static_cast
<
StrideA
*>
(
a_strides1
.
data_ptr
()),
static_cast
<
const
ElementType
**>
(
b_ptrs
.
data_ptr
()),
static_cast
<
StrideB
*>
(
b_strides1
.
data_ptr
()),
static_cast
<
const
ElementSFType
**>
(
a_scales_ptrs
.
data_ptr
()),
reinterpret_cast
<
LayoutSFA
*>
(
layout_sfa
.
data_ptr
()),
static_cast
<
const
ElementSFType
**>
(
b_scales_ptrs
.
data_ptr
()),
reinterpret_cast
<
LayoutSFB
*>
(
layout_sfb
.
data_ptr
())};
// Epilogue Arguments
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
{},
// epilogue.thread
nullptr
,
static_cast
<
StrideC
*>
(
c_strides1
.
data_ptr
()),
static_cast
<
ElementD
**>
(
out_ptrs
.
data_ptr
()),
static_cast
<
StrideC
*>
(
c_strides1
.
data_ptr
())};
auto
&
fusion_args
=
epilogue_args
.
thread
;
fusion_args
.
alpha_ptr_array
=
reinterpret_cast
<
float
**>
(
alpha_ptrs
.
data_ptr
());
fusion_args
.
dAlpha
=
{
_0
{},
_0
{},
1
};
// Gemm Arguments
typename
GemmKernel
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGrouped
,
{
num_experts
,
problem_sizes_as_shapes
,
nullptr
},
mainloop_args
,
epilogue_args
,
hw_info
,
scheduler
};
size_t
workspace_size
=
Gemm
::
get_workspace_size
(
args
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
auto
can_implement_status
=
gemm_op
.
can_implement
(
args
);
TORCH_CHECK
(
can_implement_status
==
cutlass
::
Status
::
kSuccess
,
"Failed to implement GEMM"
);
// Run the GEMM
auto
status
=
gemm_op
.
initialize
(
args
,
workspace
.
data_ptr
());
TORCH_CHECK
(
status
==
cutlass
::
Status
::
kSuccess
,
"Failed to initialize GEMM"
);
status
=
gemm_op
.
run
(
args
,
workspace
.
data_ptr
(),
stream
);
TORCH_CHECK
(
status
==
cutlass
::
Status
::
kSuccess
,
"Failed to run GEMM"
);
}
// Scalar type that packed FP4 (e2m1) operands are validated against: Byte.
// NOTE(review): the "X2" suffix suggests two e2m1 values per byte — confirm.
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
// Scalar type that block scale-factor tensors are validated against.
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;

// Validation helpers. `x` is the tensor, `st` the expected scalar type and
// `m` the label inserted into the error message.
// CHECK_TYPE: fails unless x.scalar_type() == st. The label is appended after
// the fixed text.
#define CHECK_TYPE(x, st, m) \
  TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
// CHECK_TH_CUDA: fails unless x.is_cuda().
#define CHECK_TH_CUDA(x, m) \
  TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
// CHECK_CONTIGUOUS: fails unless x.is_contiguous().
#define CHECK_CONTIGUOUS(x, m) \
  TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
// CHECK_INPUT: runs the three checks above in order. It expands to several
// statements, so it must not be used as the sole body of an unbraced if/else.
#define CHECK_INPUT(x, st, m) \
  CHECK_TH_CUDA(x, m);        \
  CHECK_CONTIGUOUS(x, m);     \
  CHECK_TYPE(x, st, m)
// Entry point for the grouped (per-expert) NVFP4 block-scaled matmul.
//
// Validates the operands, derives M/N/K from the tensor shapes and dispatches
// to run_fp4_blockwise_scaled_group_mm<OutType> according to the dtype of
// `output`.
//
// Parameters (as established by the checks below):
//   output         - result tensor; must be BFloat16 or Half.
//   a, b           - CUDA, contiguous, Byte tensors holding packed FP4 data.
//                    M = a.size(0), N = b.size(1), K = 2 * b.size(2).
//   a_blockscale   - CUDA, contiguous, Float8_e4m3fn, rank 2.
//   b_blockscales  - CUDA, contiguous, Float8_e4m3fn, rank 3.
//   alphas         - CUDA, contiguous, Float.
//   problem_sizes  - int32, shape (num_experts, 3).
//   expert_offsets - size(0) must equal problem_sizes.size(0).
//   sf_offsets     - forwarded unchanged to the kernel launcher.
//
// Raises a c10::Error when a check fails, and a NotImplementedError when the
// build does not define ENABLE_NVFP4.
void cutlass_fp4_group_mm(
    torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
    const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
    const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  // Input validation
  CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
  CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
  CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
  CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
  CHECK_INPUT(alphas, at::ScalarType::Float, "alphas");

  // NOTE(review): the message below lists three dimensions while the check
  // requires rank 2 — confirm which one reflects the intended layout.
  TORCH_CHECK(a_blockscale.dim() == 2,
              "expected a_blockscale to be of shape [num_experts, rounded_m,"
              " k // group_size], observed rank: ",
              a_blockscale.dim());
  TORCH_CHECK(b_blockscales.dim() == 3,
              "expected b_blockscale to be of shape: "
              " [num_experts, n, k // group_size], observed rank: ",
              b_blockscales.dim());
  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
  TORCH_CHECK(problem_sizes.size(1) == 3,
              "problem_sizes must have the shape (num_experts, 3)");
  TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
              "Number of experts in problem_sizes must match expert_offsets");
  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
              "problem_sizes must be int32.");

  // Only two output element types are instantiated. Reject anything else
  // instead of silently running the half kernel on a differently typed buffer.
  const auto out_dtype = output.scalar_type();
  TORCH_CHECK(out_dtype == torch::kBFloat16 || out_dtype == torch::kHalf,
              "output must be bfloat16 or half, got ", out_dtype);

  int M = static_cast<int>(a.size(0));
  int N = static_cast<int>(b.size(1));
  // Each byte of b carries two FP4 values, hence the factor of 2.
  int K = static_cast<int>(2 * b.size(2));

  if (out_dtype == torch::kBFloat16) {
    run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
        output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
        expert_offsets, sf_offsets, M, N, K);
  } else {
    run_fp4_blockwise_scaled_group_mm<cutlass::half_t>(
        output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
        expert_offsets, sf_offsets, M, N, K);
  }
#else
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_fp4_group_mm kernel, vLLM must "
      "be compiled with ENABLE_NVFP4 for SM100+ and CUDA "
      "12.8 or above.");
#endif
}
csrc/quantization/fp4/nvfp4_experts_quant.cu
0 → 100644
View file @
4c676e3d
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h>
// TypeConverter<T>::Type pairs each 16-bit floating-point type with its
// counterpart of the other width: a scalar maps to the two-element packed
// type and the packed type maps back to the scalar.

// Fallback for any type without a specialization below (kept for generality).
template <typename T>
struct TypeConverter {
  using Type = half2;
};

// Scalar -> packed pair.
template <>
struct TypeConverter<half> {
  using Type = half2;
};

template <>
struct TypeConverter<__nv_bfloat16> {
  using Type = __nv_bfloat162;
};

// Packed pair -> scalar.
template <>
struct TypeConverter<half2> {
  using Type = half;
};

template <>
struct TypeConverter<__nv_bfloat162> {
  using Type = __nv_bfloat16;
};
#define ELTS_PER_THREAD 8
constexpr
int
CVT_FP4_ELTS_PER_THREAD
=
8
;
constexpr
int
CVT_FP4_SF_VEC_SIZE
=
16
;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
//
// Each of the four cvt.rn.satfinite.e2m1x2.f32 instructions converts one pair
// of inputs into one byte. Within every pair the odd-indexed element is
// passed as the first source operand and the even-indexed one as the second.
// The four bytes are then packed into the result with byte0 listed first in
// the mov.b32 vector operand.
//
// Returns 0 when not compiled for __CUDA_ARCH__ >= 1000.
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  uint32_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(val)
      : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
        "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
  return val;
#else
  return 0;
#endif
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
//
// Overload of the float[8] variant for inputs already grouped as float2: each
// float2 supplies one input pair, with .y passed as the first source operand
// of the cvt instruction and .x as the second. array[i] produces byte i of
// the packed result.
//
// Returns 0 when not compiled for __CUDA_ARCH__ >= 1000.
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  uint32_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(val)
      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
  return val;
#else
  return 0;
#endif
}
// Fast reciprocal.
// Returns the result of the PTX instruction rcp.approx.ftz.f32 applied to
// `a`. The result is an approximation, not a correctly rounded 1/a.
inline __device__ float reciprocal_approximate_ftz(float a) {
  float b;
  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
  return b;
}
// Computes the address of the scale-factor (SF) byte that the calling thread
// is responsible for writing.
//
// The SF buffer is addressed as
//   [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4 (kTile)]
// with index
//   [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx].
//
// rowIdx / colIdx - row and per-thread column index of the element group.
// numCols         - number of columns of the input, used to count K tiles.
// SFout           - base of the SF buffer; the offset is applied in bytes.
//
// Returns nullptr for threads that do not own an SF (threadIdx.x is not a
// multiple of CVT_FP4_NUM_THREADS_PER_SF) and when not compiled for
// __CUDA_ARCH__ >= 1000.
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
                                                       int numCols,
                                                       SFType* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
                CVT_FP4_NUM_THREADS_PER_SF == 2);

  // One pair of threads writes one SF to global memory; only the first
  // thread of each group gets a destination.
  // TODO: stage through smem for packed STG.32
  // is it better than STG.8 from 4 threads ?
  if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF != 0) {
    return nullptr;
  }

  // SF vector index (16 elements share one SF in the K dimension).
  const int32_t sfCol = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
  const int32_t sfRow = rowIdx;

  // Number of K tiles per M tile: numCols rounded up to whole groups of
  // (SF vector size 16) * 4 columns.
  const int colsPerKTile = CVT_FP4_SF_VEC_SIZE * 4;
  const int32_t kTileCount = (numCols + colsPerKTile - 1) / colsPerKTile;

  // Strides of the five layout dimensions, outermost first.
  const int64_t strideMTile = kTileCount * 32 * 4 * 4;
  const int64_t strideKTile = 32 * 4 * 4;
  const int64_t strideOuterM = 4 * 4;
  const int64_t strideInnerM = 4;
  const int64_t strideInnerK = 1;

  // Indices along those dimensions. M tile layout [32, 4] is column-major.
  const int32_t idxMTile = sfRow / (32 * 4);
  const int32_t idxKTile = sfCol / 4;
  const int32_t idxOuterM = sfRow % 32;
  const int32_t idxInnerM = (sfRow % (32 * 4)) / 32;
  const int32_t idxInnerK = sfCol % 4;

  // Compute the global offset.
  int64_t byteOffset = idxMTile * strideMTile;
  byteOffset += idxKTile * strideKTile;
  byteOffset += idxOuterM * strideOuterM;
  byteOffset += idxInnerM * strideInnerM;
  byteOffset += idxInnerK * strideInnerK;

  return reinterpret_cast<uint8_t*>(SFout) + byteOffset;
#else
  return nullptr;
#endif
}
// Define a 16 bytes packed data type.
// Generic form: four elements of the paired type selected by
// TypeConverter<Type> (for half / __nv_bfloat16 this is the two-element
// packed type, i.e. 8 scalars in total).
template <class Type>
struct PackedVec {
  typename TypeConverter<Type>::Type elts[4];
};

// FP8 specialization: eight __nv_fp8x2_e4m3 elements.
template <>
struct PackedVec<__nv_fp8_e4m3> {
  __nv_fp8x2_e4m3 elts[8];
};
// Quantizes the provided PackedVec into the uint32_t output
//
// vec        - this thread's 8 input values, stored as 4 packed pairs.
// SFScaleVal - global scale applied to the block scale factor.
// SFout      - destination for the 8-bit scale factor, or nullptr when this
//              thread must not write one.
// UE8M0_SF   - when true the scale factor is stored as the 8 exponent bits of
//              its float32 representation; otherwise as __nv_fp8_e4m3.
//
// Returns the 8 converted e2m1 values packed into one uint32_t; the caller
// stores them. Returns 0 when not compiled for __CUDA_ARCH__ >= 1000.
//
// NOTE(review): the shuffle below uses a full 32-lane mask, so it assumes every
// lane of the warp reaches this call together — confirm the callers' loop
// bounds guarantee that.
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
                                         uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  // Get absolute maximum values among the local 8 values.
  auto localMax = __habs2(vec.elts[0]);

  // Local maximum value.
#pragma unroll
  for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    localMax = __hmax2(localMax, __habs2(vec.elts[i]));
  }

  // Get the absolute maximum among all 16 values (two threads).
  // Lanes exchange their local maximum with the lane whose id differs in
  // bit 0.
  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
  // Get the final absolute maximum values.
  float vecMax = float(__hmax(localMax.x, localMax.y));

  // Get the SF (max value of the vector / max value of e2m1).
  // maximum value of e2m1 = 6.0.
  // TODO: use half as compute data type.
  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
  // 8 bits representation of the SF.
  uint8_t fp8SFVal;
  // Encode the SF into 8 bits and round SFValue to the encoded value.
  if constexpr (UE8M0_SF) {
    // Extract the 8 exponent bits from float32.
    // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
    uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
    fp8SFVal = tmp & 0xff;
    // Convert back to fp32.
    reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
  } else {
    // Here SFValue is always positive, so E4M3 is the same as UE4M3.
    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
    reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
    // Convert back to fp32.
    SFValue = float(tmp);
  }
  // Get the output scale.
  // As computed below: outputScale = rcp(SFValue * rcp(SFScaleVal)), where
  // SFValue is the scale factor after the 8-bit round trip above. A zero
  // scale factor yields an output scale of 0.
  float outputScale =
      SFValue != 0 ? reciprocal_approximate_ftz(
                         SFValue * reciprocal_approximate_ftz(SFScaleVal))
                   : 0.0f;

  if (SFout) {
    // Write the SF to global memory (STG.8).
    *SFout = fp8SFVal;
  }

  // Convert the input to float.
  float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];

#pragma unroll
  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    // Any Type other than half is converted with the bfloat16 intrinsic.
    if constexpr (std::is_same_v<Type, half>) {
      fp2Vals[i] = __half22float2(vec.elts[i]);
    } else {
      fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
    }
    fp2Vals[i].x *= outputScale;
    fp2Vals[i].y *= outputScale;
  }

  // Convert to e2m1 values.
  uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);

  // Hand the packed e2m1 values back to the caller.
  return e2m1Vec;
#else
  return 0;
#endif
}
// Per-expert FP4 quantization kernel. Use UE4M3 by default.
//
// Blocks stride over rows (blockIdx.x, step gridDim.x) and threads stride
// over groups of CVT_FP4_ELTS_PER_THREAD columns (threadIdx.x, step
// blockDim.x). Every group is converted to one packed uint32_t of e2m1 values
// and, through cvt_warp_fp16_to_fp4, one scale factor per
// CVT_FP4_SF_VEC_SIZE elements is written into the expert's SF region.
//
// numRows, numCols               - extent of `in`, row-major.
// in                             - input values of type Type.
// SFScale                        - per-expert global scale, indexed by expert;
//                                  nullptr means a scale of 1.0.
// out                            - packed output, one uint32_t per group.
// SFout                          - base of the scale-factor buffer.
// input_offset_by_experts        - row boundaries per expert; entries i and
//                                  i + 1 are read for every i < n_experts.
// output_scale_offset_by_experts - per-expert row offset into SFout.
// n_experts                      - number of experts.
//
// Rows that fall into no expert range use expert 0 with in-expert row 0.
// The body is empty when not compiled for __CUDA_ARCH__ >= 1000.
template <class Type, bool UE8M0_SF = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
#else
cvt_fp16_to_fp4(
#endif
    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
    uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts,
    uint32_t* output_scale_offset_by_experts, int n_experts) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  using PackedVec = PackedVec<Type>;
  static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
      (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                "Vec size is not matched.");

  // Number of packed groups (and output words) per row.
  const int numColVecs = numCols / CVT_FP4_ELTS_PER_THREAD;

  // The actual output_scales dim is computed from the padded numCols.
  // These values depend only on numCols, so compute them once.
  const int factor = CVT_FP4_SF_VEC_SIZE * 4;
  const int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
  const int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;

  // Input tensor row/col loops.
  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
    // Find index within the experts. The result is the same for every column
    // of the row, so the search runs once per row rather than per group.
    int rowIdx_in_expert = 0;
    int expert_idx = 0;
    for (int i = 0; i < n_experts; i++) {
      if (rowIdx >= input_offset_by_experts[i] &&
          rowIdx < input_offset_by_experts[i + 1]) {
        rowIdx_in_expert = rowIdx - input_offset_by_experts[i];
        expert_idx = i;
        break;
      }
    }

    // Get the global scaling factor, which will be applied to the SF.
    // Note SFScale is the same as next GEMM's alpha, which is
    // (448.f / (Alpha_A / 6.f)).
    float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];

    // Start of this expert's scale-factor region. The product is formed in
    // 64 bits so that large offsets cannot wrap in 32-bit arithmetic.
    uint32_t* SFout_in_expert =
        SFout +
        static_cast<int64_t>(output_scale_offset_by_experts[expert_idx]) *
            numCols_SFout;

    for (int colIdx = threadIdx.x; colIdx < numColVecs; colIdx += blockDim.x) {
      // Widen before multiplying: rowIdx * numColVecs can exceed INT_MAX for
      // large inputs.
      const int64_t inOffset =
          static_cast<int64_t>(rowIdx) * numColVecs + colIdx;
      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];

      auto sf_out =
          cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
                                             CVT_FP4_NUM_THREADS_PER_SF>(
              rowIdx_in_expert, colIdx, numCols, SFout_in_expert);

      // The output offset equals the input offset because 8 elements are
      // packed into one uint32_t.
      out[inOffset] =
          cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
    }
  }
#endif
}
template
<
typename
T
>
void
quant_impl
(
void
*
output
,
void
*
output_scale
,
void
*
input
,
void
*
input_global_scale
,
void
*
input_offset_by_experts
,
void
*
output_scale_offset_by_experts
,
int
m_topk
,
int
k
,
int
n_experts
,
cudaStream_t
stream
)
{
// TODO: this multiProcessorCount should be cached.
int
device
;
cudaGetDevice
(
&
device
);
int
multiProcessorCount
;
cudaDeviceGetAttribute
(
&
multiProcessorCount
,
cudaDevAttrMultiProcessorCount
,
device
);
// Grid, Block size.
// Each thread converts 8 values.
dim3
block
(
std
::
min
(
int
(
k
/
ELTS_PER_THREAD
),
512
));
// Get number of blocks per SM (assume we can fully utilize the SM).
int
const
numBlocksPerSM
=
2048
/
block
.
x
;
dim3
grid
(
std
::
min
(
int
(
m_topk
),
multiProcessorCount
*
numBlocksPerSM
));
cvt_fp16_to_fp4
<
T
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
m_topk
,
k
,
reinterpret_cast
<
T
*>
(
input
),
reinterpret_cast
<
float
*>
(
input_global_scale
),
reinterpret_cast
<
uint32_t
*>
(
output
),
reinterpret_cast
<
uint32_t
*>
(
output_scale
),
reinterpret_cast
<
uint32_t
*>
(
input_offset_by_experts
),
reinterpret_cast
<
uint32_t
*>
(
output_scale_offset_by_experts
),
n_experts
);
}
/* Quantization entry for fp4 experts quantization */

// Validation helpers. `x` is the tensor and `m` the message prefix. The prefix
// is concatenated directly with the fixed text, with no separator inserted.
// CHECK_TH_CUDA: fails unless x.is_cuda().
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
// CHECK_CONTIGUOUS: fails unless x.is_contiguous().
#define CHECK_CONTIGUOUS(x, m) \
  TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
// CHECK_INPUT: runs both checks above. It expands to two statements, so it
// must not be used as the sole body of an unbraced if/else.
#define CHECK_INPUT(x, m) \
  CHECK_TH_CUDA(x, m);    \
  CHECK_CONTIGUOUS(x, m);

// Short aliases for the scalar types checked by the entry point below.
constexpr auto HALF = at::ScalarType::Half;
constexpr auto BF16 = at::ScalarType::BFloat16;
constexpr auto FLOAT = at::ScalarType::Float;
constexpr auto INT = at::ScalarType::Int;
constexpr auto UINT8 = at::ScalarType::Byte;
// Quantizes per-expert activations to packed FP4 with block scale factors.
//
// output                         - uint8, shape (m_topk, k / 2); two nvfp4
//                                  values are packed into one uint8.
// output_scale                   - int32, rank 2; four fp8 scale factors are
//                                  packed into one int32, so size(1) * 4 must
//                                  equal k / 16 rounded up to a multiple of 4.
// input                          - half or bfloat16, shape (m_topk, k) with k
//                                  a multiple of 16.
// input_global_scale             - float, shape (n_experts).
// input_offset_by_experts        - int32, shape (n_experts + 1).
// output_scale_offset_by_experts - int32, shape (n_experts + 1).
//
// All tensors must be contiguous CUDA tensors. Raises a c10::Error when any
// check fails. The kernel is launched on the current stream of input's
// device.
void scaled_fp4_experts_quant_sm100a(
    torch::Tensor& output, torch::Tensor& output_scale,
    torch::Tensor const& input, torch::Tensor const& input_global_scale,
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts) {
  // The macros append "must be a CUDA tensor" / "must be contiguous" directly
  // to the label, so the label is just the tensor name plus a space.
  CHECK_INPUT(output, "output ");
  CHECK_INPUT(output_scale, "output_scale ");
  CHECK_INPUT(input, "input ");
  CHECK_INPUT(input_global_scale, "input_global_scale ");
  CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts ");
  CHECK_INPUT(output_scale_offset_by_experts,
              "output_scale_offset_by_experts ");

  TORCH_CHECK(output.dim() == 2, "output must be 2D");
  TORCH_CHECK(output_scale.dim() == 2, "output_scale must be 2D");
  TORCH_CHECK(input.dim() == 2, "input must be 2D");
  TORCH_CHECK(input_global_scale.dim() == 1, "input_global_scale must be 1D");
  TORCH_CHECK(input_offset_by_experts.dim() == 1,
              "input_offset_by_experts must be 1D");
  TORCH_CHECK(output_scale_offset_by_experts.dim() == 1,
              "output_scale_offset_by_experts must be 1D");

  TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16,
              "input must be half or bfloat16");
  TORCH_CHECK(input_global_scale.scalar_type() == FLOAT,
              "input_global_scale must be float32");
  TORCH_CHECK(input_offset_by_experts.scalar_type() == INT,
              "input_offset_by_experts must be int32");
  TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT,
              "output_scale_offset_by_experts must be int32");
  // output is uint8 (two nvfp4 values are packed into one uint8)
  // output_scale is int32 (four fp8 values are packed into one int32)
  TORCH_CHECK(output.scalar_type() == UINT8, "output must be uint8");
  TORCH_CHECK(output_scale.scalar_type() == INT, "output_scale must be int32");

  const int BLOCK_SIZE = 16;
  const auto m_topk = input.size(0);
  const auto k = input.size(1);
  TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
  const auto n_experts = input_global_scale.size(0);
  TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1,
              "input_offset_by_experts must have n_experts + 1 entries");
  TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1,
              "output_scale_offset_by_experts must have n_experts + 1 entries");
  TORCH_CHECK(output.size(0) == m_topk,
              "output must have the same number of rows as input");
  TORCH_CHECK(output.size(1) == k / 2, "output.size(1) must equal k / 2");

  // Keep the size arithmetic in the tensors' native 64-bit width.
  const auto scales_k = k / BLOCK_SIZE;
  // 4 means the swizzle requirement by nvidia nvfp4.
  const auto padded_k = (scales_k + (4 - 1)) / 4 * 4;
  // 4 means 4 fp8 values are packed into one int32
  TORCH_CHECK(output_scale.size(1) * 4 == padded_k,
              "output_scale.size(1) * 4 must equal the padded scale count");

  const auto in_dtype = input.scalar_type();
  // Make input's device current for the launch below.
  at::cuda::CUDAGuard device_guard{input.device()};
  const cudaStream_t stream =
      at::cuda::getCurrentCUDAStream(input.get_device());

  if (in_dtype == at::ScalarType::Half) {
    quant_impl<half>(output.data_ptr(), output_scale.data_ptr(),
                     input.data_ptr(), input_global_scale.data_ptr(),
                     input_offset_by_experts.data_ptr(),
                     output_scale_offset_by_experts.data_ptr(),
                     static_cast<int>(m_topk), static_cast<int>(k),
                     static_cast<int>(n_experts), stream);
  } else if (in_dtype == at::ScalarType::BFloat16) {
    quant_impl<__nv_bfloat16>(output.data_ptr(), output_scale.data_ptr(),
                              input.data_ptr(), input_global_scale.data_ptr(),
                              input_offset_by_experts.data_ptr(),
                              output_scale_offset_by_experts.data_ptr(),
                              static_cast<int>(m_topk), static_cast<int>(k),
                              static_cast<int>(n_experts), stream);
  } else {
    TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
  }
}
\ No newline at end of file
csrc/quantization/fp4/nvfp4_quant_entry.cu
View file @
4c676e3d
...
@@ -23,10 +23,32 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output,
...
@@ -23,10 +23,32 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output,
torch
::
Tensor
const
&
input_sf
);
torch
::
Tensor
const
&
input_sf
);
#endif
#endif
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
void
scaled_fp4_experts_quant_sm100a
(
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_global_scale
,
torch
::
Tensor
const
&
input_offset_by_experts
,
torch
::
Tensor
const
&
output_scale_offset_by_experts
);
#endif
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
output_sf
,
torch
::
Tensor
const
&
input_sf
)
{
torch
::
Tensor
&
output_sf
,
torch
::
Tensor
const
&
input_sf
)
{
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
return
scaled_fp4_quant_sm100a
(
output
,
input
,
output_sf
,
input_sf
);
return
scaled_fp4_quant_sm100a
(
output
,
input
,
output_sf
,
input_sf
);
#endif
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled nvfp4 quantization"
);
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled nvfp4 quantization kernel"
);
}
// Public entry point for per-expert FP4 quantization.
// Forwards every argument to scaled_fp4_experts_quant_sm100a when the build
// enables NVFP4; otherwise raises a NotImplementedError.
void scaled_fp4_experts_quant(
    torch::Tensor& output, torch::Tensor& output_scale,
    torch::Tensor const& input, torch::Tensor const& input_global_scale,
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  scaled_fp4_experts_quant_sm100a(output, output_scale, input,
                                  input_global_scale, input_offset_by_experts,
                                  output_scale_offset_by_experts);
#else
  TORCH_CHECK_NOT_IMPLEMENTED(false,
                              "No compiled nvfp4 experts quantization kernel");
#endif
}
}
csrc/quantization/fp8/common.cu
View file @
4c676e3d
...
@@ -39,8 +39,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
...
@@ -39,8 +39,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
fp8_type
*
__restrict__
token_output
=
&
out
[
offset
];
fp8_type
*
__restrict__
token_output
=
&
out
[
offset
];
// For vectorization, token_input and token_output pointers need to be
// For vectorization, token_input and token_output pointers need to be
// aligned at
8
-byte and
4
-byte addresses respectively.
// aligned at
32
-byte and
16
-byte addresses respectively.
bool
const
can_vectorize
=
hidden_size
%
4
==
0
;
bool
const
can_vectorize
=
hidden_size
%
16
==
0
;
float
absmax_val
=
0.0
f
;
float
absmax_val
=
0.0
f
;
if
(
can_vectorize
)
{
if
(
can_vectorize
)
{
...
@@ -48,24 +48,24 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
...
@@ -48,24 +48,24 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
}
else
{
}
else
{
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
float
const
x
=
static_cast
<
float
>
(
token_input
[
i
]);
float
const
x
=
static_cast
<
float
>
(
token_input
[
i
]);
absmax_val
=
max
(
absmax_val
,
fabs
(
x
));
absmax_val
=
f
max
f
(
absmax_val
,
fabs
f
(
x
));
}
}
}
}
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
256
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStorage
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStorage
;
float
const
block_absmax_val_maybe
=
float
const
block_absmax_val_maybe
=
BlockReduce
(
reduceStorage
).
Reduce
(
absmax_val
,
cub
::
Max
{},
blockDim
.
x
);
BlockReduce
(
reduceStorage
).
Reduce
(
absmax_val
,
cub
::
Max
{},
blockDim
.
x
);
__shared__
float
token_scale
;
__shared__
float
token_scale
;
if
(
tid
==
0
)
{
if
(
tid
==
0
)
{
if
(
scale_ub
)
{
if
(
scale_ub
)
{
token_scale
=
min
(
block_absmax_val_maybe
,
*
scale_ub
);
token_scale
=
f
min
f
(
block_absmax_val_maybe
,
*
scale_ub
);
}
else
{
}
else
{
token_scale
=
block_absmax_val_maybe
;
token_scale
=
block_absmax_val_maybe
;
}
}
// token scale computation
// token scale computation
token_scale
=
max
(
token_scale
/
quant_type_max_v
<
fp8_type
>
,
token_scale
=
f
max
f
(
token_scale
/
quant_type_max_v
<
fp8_type
>
,
min_scaling_factor
<
fp8_type
>::
val
());
min_scaling_factor
<
fp8_type
>::
val
());
scale
[
token_idx
]
=
token_scale
;
scale
[
token_idx
]
=
token_scale
;
}
}
__syncthreads
();
__syncthreads
();
...
@@ -88,10 +88,11 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
...
@@ -88,10 +88,11 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch
::
Tensor
const
&
input
,
// [..., d]
torch
::
Tensor
const
&
input
,
// [..., d]
torch
::
Tensor
const
&
scale
)
// [1]
torch
::
Tensor
const
&
scale
)
// [1]
{
{
int64_t
num_tokens
=
input
.
numel
()
/
input
.
size
(
-
1
);
int
const
block_size
=
256
;
int64_t
num_elems
=
input
.
numel
();
int
const
num_tokens
=
input
.
numel
()
/
input
.
size
(
-
1
);
dim3
grid
(
num_tokens
);
int
const
num_elems
=
input
.
numel
();
dim3
block
(
1024
);
dim3
const
grid
(
num_tokens
);
dim3
const
block
(
block_size
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
VLLM_DISPATCH_FLOATING_TYPES
(
...
@@ -110,10 +111,11 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
...
@@ -110,10 +111,11 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch
::
Tensor
const
&
input
,
// [..., d]
torch
::
Tensor
const
&
input
,
// [..., d]
torch
::
Tensor
&
scale
)
// [1]
torch
::
Tensor
&
scale
)
// [1]
{
{
int64_t
num_tokens
=
input
.
numel
()
/
input
.
size
(
-
1
);
int
const
block_size
=
256
;
int64_t
num_elems
=
input
.
numel
();
int
const
num_tokens
=
input
.
numel
()
/
input
.
size
(
-
1
);
dim3
grid
(
num_tokens
);
int
const
num_elems
=
input
.
numel
();
dim3
block
(
1024
);
dim3
const
grid
(
num_tokens
);
dim3
const
block
(
block_size
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
VLLM_DISPATCH_FLOATING_TYPES
(
...
@@ -141,8 +143,9 @@ void dynamic_per_token_scaled_fp8_quant(
...
@@ -141,8 +143,9 @@ void dynamic_per_token_scaled_fp8_quant(
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
int
const
block_size
=
256
;
dim3
const
grid
(
num_tokens
);
dim3
const
grid
(
num_tokens
);
dim3
const
block
(
std
::
min
(
hidden_size
,
1024
));
dim3
const
block
(
std
::
min
(
hidden_size
,
block_size
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
...
...
Prev
1
…
4
5
6
7
8
9
10
11
12
…
21
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment