Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7a985548
Commit
7a985548
authored
May 22, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.9.0' into v0.9.0-ori
parents
45d3785c
dc1440cf
Changes
486
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2086 additions
and
149 deletions
+2086
-149
csrc/moe/permute_unpermute_kernels/dispatch.h
csrc/moe/permute_unpermute_kernels/dispatch.h
+53
-0
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
...permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
+229
-0
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
.../permute_unpermute_kernels/moe_permute_unpermute_kernel.h
+95
-0
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl
...ermute_unpermute_kernels/moe_permute_unpermute_kernel.inl
+211
-0
csrc/moe/topk_softmax_kernels.cu
csrc/moe/topk_softmax_kernels.cu
+45
-18
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+24
-1
csrc/ops.h
csrc/ops.h
+48
-4
csrc/pos_encoding_kernels.cu
csrc/pos_encoding_kernels.cu
+80
-52
csrc/quantization/activation_kernels.cu
csrc/quantization/activation_kernels.cu
+121
-0
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+14
-2
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
...ization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
+27
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
...tlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
+205
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
+75
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
+5
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
+5
-17
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
+5
-46
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+15
-8
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
+402
-0
csrc/quantization/fp4/nvfp4_experts_quant.cu
csrc/quantization/fp4/nvfp4_experts_quant.cu
+404
-0
csrc/quantization/fp4/nvfp4_quant_entry.cu
csrc/quantization/fp4/nvfp4_quant_entry.cu
+23
-1
No files found.
Too many changes to show.
To preserve performance only
486 of 486+
files are displayed.
Plain diff
Email patch
csrc/moe/permute_unpermute_kernels/dispatch.h
0 → 100644
View file @
7a985548
#pragma once
#include <cuda_fp8.h>
#define MOE_SWITCH(TYPE, ...) \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
switch (_st) { \
__VA_ARGS__ \
default: \
TORCH_CHECK(false, "[moe permute]data type dispatch fail!") \
}
#define MOE_DISPATCH_CASE(enum_type, ...) \
case enum_type: { \
using scalar_t = ScalarType2CudaType<enum_type>::type; \
__VA_ARGS__(); \
break; \
}
#define MOE_DISPATCH_FLOAT_CASE(...) \
MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
#define MOE_DISPATCH(TYPE, ...) \
MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__))
template
<
at
::
ScalarType
type
>
struct
ScalarType2CudaType
;
template
<
>
struct
ScalarType2CudaType
<
at
::
ScalarType
::
Float
>
{
using
type
=
float
;
};
template
<
>
struct
ScalarType2CudaType
<
at
::
ScalarType
::
Half
>
{
using
type
=
half
;
};
template
<
>
struct
ScalarType2CudaType
<
at
::
ScalarType
::
BFloat16
>
{
using
type
=
__nv_bfloat16
;
};
// #if __CUDA_ARCH__ >= 890
// fp8
template
<
>
struct
ScalarType2CudaType
<
at
::
ScalarType
::
Float8_e5m2
>
{
using
type
=
__nv_fp8_e5m2
;
};
template
<
>
struct
ScalarType2CudaType
<
at
::
ScalarType
::
Float8_e4m3fn
>
{
using
type
=
__nv_fp8_e4m3
;
};
// #endif
\ No newline at end of file
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
0 → 100644
View file @
7a985548
#include "moe_permute_unpermute_kernel.h"
// CubKeyValueSorter definition begin
CubKeyValueSorter
::
CubKeyValueSorter
()
:
num_experts_
(
0
),
num_bits_
(
sizeof
(
int
)
*
8
)
{}
int
CubKeyValueSorter
::
expertsToBits
(
int
num_experts
)
{
// Max value we represent is V = num_experts + (num_experts - 1) = 2 *
// num_experts - 1 The maximum number of bits is therefore floor(log2(V)) + 1
return
static_cast
<
int
>
(
log2
(
2
*
num_experts
-
1
))
+
1
;
}
CubKeyValueSorter
::
CubKeyValueSorter
(
int
const
num_experts
)
:
num_experts_
(
num_experts
),
num_bits_
(
expertsToBits
(
num_experts
))
{}
void
CubKeyValueSorter
::
updateNumExperts
(
int
const
num_experts
)
{
num_experts_
=
num_experts
;
num_bits_
=
expertsToBits
(
num_experts
);
}
size_t
CubKeyValueSorter
::
getWorkspaceSize
(
size_t
const
num_key_value_pairs
,
int
const
num_experts
)
{
int
num_bits
=
expertsToBits
(
num_experts
);
size_t
required_storage
=
0
;
int
*
null_int
=
nullptr
;
cub
::
DeviceRadixSort
::
SortPairs
(
nullptr
,
required_storage
,
null_int
,
null_int
,
null_int
,
null_int
,
num_key_value_pairs
,
0
,
num_bits
);
// when num_key_value_pairs, num_experts, num_bits, required_storage = 64,
// 4, 3, 0 The required_storage seems to vary between 0 and 1 for the same
// inputs
if
(
required_storage
==
0
)
{
required_storage
=
1
;
}
return
required_storage
;
}
void
CubKeyValueSorter
::
run
(
void
*
workspace
,
size_t
const
workspace_size
,
int
const
*
keys_in
,
int
*
keys_out
,
int
const
*
values_in
,
int
*
values_out
,
size_t
const
num_key_value_pairs
,
cudaStream_t
stream
)
{
size_t
expected_ws_size
=
getWorkspaceSize
(
num_key_value_pairs
,
num_experts_
);
size_t
actual_ws_size
=
workspace_size
;
TORCH_CHECK
(
expected_ws_size
<=
workspace_size
,
"[CubKeyValueSorter::run] The allocated workspace is too small "
"to run this problem."
);
cub
::
DeviceRadixSort
::
SortPairs
(
workspace
,
actual_ws_size
,
keys_in
,
keys_out
,
values_in
,
values_out
,
num_key_value_pairs
,
0
,
num_bits_
,
stream
);
}
// CubKeyValueSorter definition end
static
inline
size_t
pad_to_multiple_of_16
(
size_t
const
&
input
)
{
static
constexpr
int
ALIGNMENT
=
16
;
return
ALIGNMENT
*
((
input
+
ALIGNMENT
-
1
)
/
ALIGNMENT
);
}
template
<
class
T
>
__device__
inline
int64_t
findTotalEltsLessThanTarget
(
T
const
*
sorted_indices
,
int64_t
const
arr_length
,
T
const
target
)
{
int64_t
low
=
0
,
high
=
arr_length
-
1
,
target_location
=
-
1
;
while
(
low
<=
high
)
{
int64_t
mid
=
(
low
+
high
)
/
2
;
if
(
sorted_indices
[
mid
]
>=
target
)
{
high
=
mid
-
1
;
}
else
{
low
=
mid
+
1
;
target_location
=
mid
;
}
}
return
target_location
+
1
;
}
// Calculates the start offset of the tokens for a given expert. The last
// element is the total number of valid tokens
__global__
void
computeExpertFirstTokenOffsetKernel
(
int
const
*
sorted_experts
,
int64_t
const
sorted_experts_len
,
int
const
num_experts
,
int64_t
*
expert_first_token_offset
)
{
// First, compute the global tid. We only need 1 thread per expert.
int
const
expert
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
// Note that expert goes [0, num_experts] (inclusive) because we want a count
// for the total number of active tokens at the end of the scan.
if
(
expert
>=
num_experts
+
1
)
{
return
;
}
expert_first_token_offset
[
expert
]
=
findTotalEltsLessThanTarget
(
sorted_experts
,
sorted_experts_len
,
expert
);
}
void
computeExpertFirstTokenOffset
(
int
const
*
sorted_indices
,
int
const
total_indices
,
int
const
num_experts
,
int64_t
*
expert_first_token_offset
,
cudaStream_t
stream
)
{
int
const
num_entries
=
num_experts
+
1
;
int
const
threads
=
std
::
min
(
1024
,
num_entries
);
int
const
blocks
=
(
num_entries
+
threads
-
1
)
/
threads
;
computeExpertFirstTokenOffsetKernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
sorted_indices
,
total_indices
,
num_experts
,
expert_first_token_offset
);
}
void
sortAndScanExpert
(
int
*
expert_for_source_row
,
const
int
*
source_rows
,
int
*
permuted_experts
,
int
*
permuted_rows
,
int64_t
*
expert_first_token_offset
,
int
num_rows
,
int
num_experts
,
int
num_experts_per_node
,
int
k
,
CubKeyValueSorter
&
sorter
,
void
*
sorter_ws
,
cudaStream_t
stream
)
{
int64_t
const
expanded_num_rows
=
static_cast
<
int64_t
>
(
k
)
*
num_rows
;
// We need to use the full num_experts because that is the sentinel value used
// by topk for disabled experts
sorter
.
updateNumExperts
(
num_experts
);
size_t
const
sorter_ws_size_bytes
=
pad_to_multiple_of_16
(
sorter
.
getWorkspaceSize
(
expanded_num_rows
,
num_experts
));
sorter
.
run
((
void
*
)
sorter_ws
,
sorter_ws_size_bytes
,
expert_for_source_row
,
permuted_experts
,
source_rows
,
permuted_rows
,
expanded_num_rows
,
stream
);
computeExpertFirstTokenOffset
(
permuted_experts
,
expanded_num_rows
,
num_experts_per_node
,
expert_first_token_offset
,
stream
);
}
__global__
void
preprocessTopkIdKernel
(
int
*
topk_id_ptr
,
int
size
,
const
int
*
expert_map_ptr
,
int
num_experts
)
{
auto
tidx
=
threadIdx
.
x
;
auto
bidx
=
blockIdx
.
x
;
auto
lidx
=
tidx
&
31
;
auto
widx
=
tidx
>>
5
;
auto
warp_count
=
(
blockDim
.
x
+
31
)
>>
5
;
auto
offset
=
bidx
*
blockDim
.
x
;
auto
bound
=
min
(
offset
+
blockDim
.
x
,
size
);
extern
__shared__
int
smem_expert_map
[];
// store expert_map in smem
for
(
int
i
=
tidx
;
i
<
num_experts
;
i
+=
blockDim
.
x
)
{
smem_expert_map
[
i
]
=
expert_map_ptr
[
i
];
}
__syncthreads
();
// query global expert id in expert map.
// if global expert id = -1 in exert map, plus n_expert
// else set global expert id = exert map[global expert id]
if
(
offset
+
tidx
<
bound
)
{
auto
topk_id
=
topk_id_ptr
[
offset
+
tidx
];
auto
local_expert_idx
=
smem_expert_map
[
topk_id
];
if
(
local_expert_idx
==
-
1
)
{
topk_id
+=
num_experts
;
}
else
{
topk_id
=
local_expert_idx
;
}
__syncwarp
();
topk_id_ptr
[
offset
+
tidx
]
=
topk_id
;
}
}
void
preprocessTopkIdLauncher
(
int
*
topk_id_ptr
,
int
size
,
const
int
*
expert_map_ptr
,
int
num_experts
,
cudaStream_t
stream
)
{
int
block
=
std
::
min
(
size
,
1024
);
int
grid
=
(
size
+
block
-
1
)
/
block
;
int
smem_size
=
(
num_experts
)
*
sizeof
(
int
);
preprocessTopkIdKernel
<<<
grid
,
block
,
smem_size
,
stream
>>>
(
topk_id_ptr
,
size
,
expert_map_ptr
,
num_experts
);
}
template
<
bool
ALIGN_BLOCK_SIZE
>
__global__
void
getMIndicesKernel
(
int64_t
*
expert_first_token_offset
,
int64_t
*
align_expert_first_token_offset
,
int
*
m_indices
,
const
int
num_local_expert
,
const
int
align_block_size
)
{
int
eidx
=
blockIdx
.
x
;
int
tidx
=
threadIdx
.
x
;
extern
__shared__
int64_t
smem_expert_first_token_offset
[];
for
(
int
i
=
tidx
;
i
<=
num_local_expert
;
i
+=
blockDim
.
x
)
{
smem_expert_first_token_offset
[
tidx
]
=
__ldg
(
expert_first_token_offset
+
i
);
}
__syncthreads
();
auto
last_token_offset
=
smem_expert_first_token_offset
[
eidx
+
1
];
auto
first_token_offset
=
smem_expert_first_token_offset
[
eidx
];
int
n_token_in_expert
=
last_token_offset
-
first_token_offset
;
if
constexpr
(
ALIGN_BLOCK_SIZE
)
{
n_token_in_expert
=
(
n_token_in_expert
+
align_block_size
-
1
)
/
align_block_size
*
align_block_size
;
// round up to ALIGN_BLOCK_SIZE
int64_t
accumulate_align_offset
=
0
;
for
(
int
i
=
1
;
i
<=
eidx
+
1
;
i
++
)
{
int
n_token
=
smem_expert_first_token_offset
[
i
]
-
smem_expert_first_token_offset
[
i
-
1
];
accumulate_align_offset
=
accumulate_align_offset
+
(
n_token
+
align_block_size
-
1
)
/
align_block_size
*
align_block_size
;
if
(
i
==
eidx
)
{
first_token_offset
=
accumulate_align_offset
;
}
// last block store align_expert_first_token_offset
if
(
eidx
==
num_local_expert
-
1
&&
threadIdx
.
x
==
0
)
{
align_expert_first_token_offset
[
i
]
=
accumulate_align_offset
;
}
}
}
for
(
int
idx
=
tidx
;
idx
<
n_token_in_expert
;
idx
+=
blockDim
.
x
)
{
// update m_indice with expert id
m_indices
[
first_token_offset
+
idx
]
=
eidx
;
}
}
void
getMIndices
(
int64_t
*
expert_first_token_offset
,
int64_t
*
align_expert_first_token_offset
,
int
*
m_indices
,
int
num_local_expert
,
const
int
align_block_size
,
cudaStream_t
stream
)
{
int
block
=
256
;
int
grid
=
num_local_expert
;
int
smem_size
=
sizeof
(
int64_t
)
*
(
num_local_expert
+
1
);
if
(
align_block_size
==
-
1
)
{
getMIndicesKernel
<
false
><<<
grid
,
block
,
smem_size
,
stream
>>>
(
expert_first_token_offset
,
align_expert_first_token_offset
,
m_indices
,
num_local_expert
,
align_block_size
);
}
else
{
getMIndicesKernel
<
true
><<<
grid
,
block
,
smem_size
,
stream
>>>
(
expert_first_token_offset
,
align_expert_first_token_offset
,
m_indices
,
num_local_expert
,
align_block_size
);
}
}
\ No newline at end of file
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
0 → 100644
View file @
7a985548
#pragma once
// reference from tensorrt_llm moe kernel implementation archive in
// https://github.com/BBuf/tensorrt-llm-moe/tree/master
#include <c10/core/ScalarType.h>
#include <torch/all.h>
#include "dispatch.h"
#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <cub/util_type.cuh>
#include "cutlass/numeric_size.h"
#include "cutlass/array.h"
template
<
typename
T
>
inline
T
*
get_ptr
(
torch
::
Tensor
&
t
)
{
return
reinterpret_cast
<
T
*>
(
t
.
data_ptr
());
}
template
<
typename
T
>
inline
const
T
*
get_ptr
(
const
torch
::
Tensor
&
t
)
{
return
reinterpret_cast
<
const
T
*>
(
t
.
data_ptr
());
}
class
CubKeyValueSorter
{
public:
CubKeyValueSorter
();
CubKeyValueSorter
(
int
const
num_experts
);
void
updateNumExperts
(
int
const
num_experts
);
static
size_t
getWorkspaceSize
(
size_t
const
num_key_value_pairs
,
int
const
num_experts
);
void
run
(
void
*
workspace
,
size_t
const
workspace_size
,
int
const
*
keys_in
,
int
*
keys_out
,
int
const
*
values_in
,
int
*
values_out
,
size_t
const
num_key_value_pairs
,
cudaStream_t
stream
);
private:
static
int
expertsToBits
(
int
experts
);
int
num_experts_
;
int
num_bits_
;
};
void
computeExpertFirstTokenOffset
(
int
const
*
sorted_indices
,
int
const
total_indices
,
int
const
num_experts
,
int64_t
*
expert_first_token_offset
,
cudaStream_t
stream
);
void
sortAndScanExpert
(
int
*
expert_for_source_row
,
const
int
*
source_rows
,
int
*
permuted_experts
,
int
*
permuted_rows
,
int64_t
*
expert_first_token_offset
,
int
num_rows
,
int
num_experts
,
int
num_experts_per_node
,
int
k
,
CubKeyValueSorter
&
sorter
,
void
*
sorter_ws
,
cudaStream_t
stream
);
template
<
typename
T
>
void
expandInputRowsKernelLauncher
(
T
const
*
unpermuted_input
,
T
*
permuted_output
,
const
float
*
unpermuted_scales
,
int
*
sorted_experts
,
int
const
*
expanded_dest_row_to_expanded_source_row
,
int
*
expanded_source_row_to_expanded_dest_row
,
int64_t
*
expert_first_token_offset
,
int64_t
const
num_rows
,
int64_t
const
*
num_valid_tokens_ptr
,
int64_t
const
cols
,
int
const
k
,
int
num_local_experts
,
const
int
&
align_block_size
,
cudaStream_t
stream
);
// Final kernel to unpermute and scale
// This kernel unpermutes the original data, does the k-way reduction and
// performs the final skip connection.
template
<
typename
T
,
typename
OutputType
,
bool
CHECK_SKIPPED
>
__global__
void
finalizeMoeRoutingKernel
(
T
const
*
expanded_permuted_rows
,
OutputType
*
reduced_unpermuted_output
,
float
const
*
scales
,
int
const
*
expanded_source_row_to_expanded_dest_row
,
int
const
*
expert_for_source_row
,
int64_t
const
orig_cols
,
int64_t
const
k
,
int64_t
const
*
num_valid_ptr
);
template
<
class
T
,
class
OutputType
>
void
finalizeMoeRoutingKernelLauncher
(
T
const
*
expanded_permuted_rows
,
OutputType
*
reduced_unpermuted_output
,
float
const
*
scales
,
int
const
*
expanded_source_row_to_expanded_dest_row
,
int
const
*
expert_for_source_row
,
int64_t
const
num_rows
,
int64_t
const
cols
,
int64_t
const
k
,
int64_t
const
*
num_valid_ptr
,
cudaStream_t
stream
);
void
preprocessTopkIdLauncher
(
int
*
topk_id_ptr
,
int
size
,
const
int
*
expert_map_ptr
,
int
num_experts
,
cudaStream_t
stream
);
void
getMIndices
(
int64_t
*
expert_first_token_offset
,
int64_t
*
align_expert_first_token_offset
,
int
*
m_indices
,
int
num_local_expert
,
const
int
align_block_size
,
cudaStream_t
stream
);
#include "moe_permute_unpermute_kernel.inl"
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl
0 → 100644
View file @
7a985548
#pragma once
template <typename T, bool CHECK_SKIPPED, bool ALIGN_BLOCK_SIZE>
__global__ void expandInputRowsKernel(
T const* unpermuted_input, T* permuted_output,
const float* unpermuted_scales, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_dest_rows, int64_t const cols, int64_t k,
int num_local_experts, int align_block_size) {
// Reverse permutation map.
// I do this so that later, we can use the source -> dest map to do the k-way
// reduction and unpermuting. I need the reverse map for that reduction to
// allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
// thread block will be responsible for all k summations.
int64_t expanded_dest_row = blockIdx.x;
int64_t const expanded_source_row =
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
int expert_id = sorted_experts[expanded_dest_row];
extern __shared__ int64_t smem_expert_first_token_offset[];
int64_t align_expanded_row_accumulate = 0;
if constexpr (ALIGN_BLOCK_SIZE) {
// load g2s
for (int idx = threadIdx.x; idx < num_local_experts + 1;
idx += blockDim.x) {
smem_expert_first_token_offset[idx] =
__ldg(expert_first_token_offset + idx);
}
__syncthreads();
int lane_idx = threadIdx.x & 31;
if (lane_idx == 0) {
// set token_offset_in_expert = 0 if this expert is not local expert
int token_offset_in_expert =
expert_id >= num_local_experts
? 0
: expanded_dest_row - smem_expert_first_token_offset[expert_id];
int64_t accumulate_align_offset = 0;
#pragma unroll 1
for (int eidx = 1; eidx <= min(expert_id, num_local_experts); eidx++) {
auto n_token_in_expert = smem_expert_first_token_offset[eidx] -
smem_expert_first_token_offset[eidx - 1];
accumulate_align_offset += (n_token_in_expert + align_block_size - 1) /
align_block_size * align_block_size;
}
expanded_dest_row = accumulate_align_offset + token_offset_in_expert;
}
// lane0 shuffle broadcast align_expanded_dest_row
expanded_dest_row = __shfl_sync(0xffffffff, expanded_dest_row, 0);
}
if (threadIdx.x == 0) {
assert(expanded_dest_row <= INT32_MAX);
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
static_cast<int>(expanded_dest_row);
}
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
// Load 128-bits per thread
constexpr int64_t ELEM_PER_THREAD = 128 / cutlass::sizeof_bits<T>::value;
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
// Duplicate and permute rows
int64_t const source_k_rank = expanded_source_row / num_rows;
int64_t const source_row = expanded_source_row % num_rows;
auto const* source_row_ptr =
reinterpret_cast<DataElem const*>(unpermuted_input + source_row * cols);
auto* dest_row_ptr =
reinterpret_cast<DataElem*>(permuted_output + expanded_dest_row * cols);
int64_t const start_offset = threadIdx.x;
int64_t const stride = blockDim.x;
int64_t const num_elems_in_col = cols / ELEM_PER_THREAD;
for (int elem_index = start_offset; elem_index < num_elems_in_col;
elem_index += stride) {
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
}
}
}
template <typename T>
void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output,
const float* unpermuted_scales, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream) {
int64_t const blocks = num_rows * k;
int64_t const threads = 256;
using FuncPtr = decltype(&expandInputRowsKernel<T, true, true>);
FuncPtr func_map[2][2] = {
{&expandInputRowsKernel<T, false, false>,
&expandInputRowsKernel<T, false, true>},
{&expandInputRowsKernel<T, true, false>,
&expandInputRowsKernel<T, true, true>},
};
bool is_check_skip = num_valid_tokens_ptr != nullptr;
bool is_align_block_size = align_block_size != -1;
auto func = func_map[is_check_skip][is_align_block_size];
int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1);
func<<<blocks, threads, smem_size, stream>>>(
unpermuted_input, permuted_output, unpermuted_scales, sorted_experts,
expanded_dest_row_to_expanded_source_row,
expanded_source_row_to_expanded_dest_row, expert_first_token_offset,
num_rows, num_valid_tokens_ptr, cols, k, num_local_experts,
align_block_size);
}
template <class T, class U>
__host__ __device__ constexpr static U arrayConvert(T const& input) {
using Type = typename U::Element;
static_assert(T::kElements == U::kElements);
U u;
#pragma unroll
for (int i = 0; i < U::kElements; i++) {
u[i] = static_cast<Type>(input[i]);
}
return u;
}
template <typename T, typename OutputType, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
int64_t const* num_valid_ptr) {
assert(orig_cols % 4 == 0);
int64_t const original_row = blockIdx.x;
int64_t const num_rows = gridDim.x;
auto const offset = original_row * orig_cols;
OutputType* reduced_row_ptr = reduced_unpermuted_output + offset;
int64_t const num_valid = *num_valid_ptr;
// Load 128-bits per thread, according to the smallest data type we read/write
constexpr int64_t FINALIZE_ELEM_PER_THREAD =
128 / std::min(cutlass::sizeof_bits<OutputType>::value,
cutlass::sizeof_bits<T>::value);
int64_t const start_offset = threadIdx.x;
int64_t const stride = blockDim.x;
int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD;
using InputElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
using OutputElem = cutlass::Array<OutputType, FINALIZE_ELEM_PER_THREAD>;
using ComputeElem = cutlass::Array<float, FINALIZE_ELEM_PER_THREAD>;
auto const* expanded_permuted_rows_v =
reinterpret_cast<InputElem const*>(expanded_permuted_rows);
auto* reduced_row_ptr_v = reinterpret_cast<OutputElem*>(reduced_row_ptr);
#pragma unroll
for (int elem_index = start_offset; elem_index < num_elems_in_col;
elem_index += stride) {
ComputeElem thread_output;
thread_output.fill(0);
float row_rescale{0.f};
for (int k_idx = 0; k_idx < k; ++k_idx) {
int64_t const expanded_original_row = original_row + k_idx * num_rows;
int64_t const expanded_permuted_row =
expanded_source_row_to_expanded_dest_row[expanded_original_row];
int64_t const k_offset = original_row * k + k_idx;
float const row_scale = scales[k_offset];
// Check after row_rescale has accumulated
if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) {
continue;
}
auto const* expanded_permuted_rows_row_ptr =
expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col;
int64_t const expert_idx = expert_for_source_row[k_offset];
ComputeElem expert_result = arrayConvert<InputElem, ComputeElem>(
expanded_permuted_rows_row_ptr[elem_index]);
thread_output = thread_output + row_scale * (expert_result);
}
OutputElem output_elem =
arrayConvert<ComputeElem, OutputElem>(thread_output);
reduced_row_ptr_v[elem_index] = output_elem;
}
}
template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const num_rows,
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
cudaStream_t stream) {
int64_t const blocks = num_rows;
int64_t const threads = 256;
bool const check_finished = num_valid_ptr != nullptr;
using FuncPtr = decltype(&finalizeMoeRoutingKernel<T, OutputType, false>);
FuncPtr func_map[2] = {&finalizeMoeRoutingKernel<T, OutputType, false>,
&finalizeMoeRoutingKernel<T, OutputType, true>};
auto* const kernel = func_map[check_finished];
kernel<<<blocks, threads, 0, stream>>>(
expanded_permuted_rows, reduced_unpermuted_output, scales,
expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k,
num_valid_ptr);
}
csrc/moe/topk_softmax_kernels.cu
View file @
7a985548
...
...
@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
}
}
template
<
int
TPB
>
__launch_bounds__
(
TPB
)
__global__
void
moeTopK
(
const
float
*
inputs_after_softmax
,
const
bool
*
finished
,
float
*
output
,
int
*
indices
,
int
*
source_rows
,
const
int
num_experts
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
)
template
<
int
TPB
,
typename
IndType
>
__launch_bounds__
(
TPB
)
__global__
void
moeTopK
(
const
float
*
inputs_after_softmax
,
const
bool
*
finished
,
float
*
output
,
IndType
*
indices
,
int
*
source_rows
,
const
int
num_experts
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
)
{
using
cub_kvp
=
cub
::
KeyValuePair
<
int
,
float
>
;
...
...
@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
2) This implementation assumes k is small, but will work for any k.
*/
template
<
int
VPT
,
int
NUM_EXPERTS
,
int
WARPS_PER_CTA
,
int
BYTES_PER_LDG
>
template
<
int
VPT
,
int
NUM_EXPERTS
,
int
WARPS_PER_CTA
,
int
BYTES_PER_LDG
,
typename
IndType
>
__launch_bounds__
(
WARPS_PER_CTA
*
WARP_SIZE
)
__global__
void
topkGatingSoftmax
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
const
int
num_rows
,
int
*
indices
,
void
topkGatingSoftmax
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
const
int
num_rows
,
IndType
*
indices
,
int
*
source_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
)
{
// We begin by enforcing compile time assertions and setting up compile time constants.
...
...
@@ -397,8 +405,8 @@ struct TopkConstants
};
}
// namespace detail
template
<
int
EXPERTS
,
int
WARPS_PER_TB
>
void
topkGatingSoftmaxLauncherHelper
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
int
*
indices
,
template
<
int
EXPERTS
,
int
WARPS_PER_TB
,
typename
IndType
>
void
topkGatingSoftmaxLauncherHelper
(
const
float
*
input
,
const
bool
*
finished
,
float
*
output
,
IndType
*
indices
,
int
*
source_row
,
const
int
num_rows
,
const
int
k
,
const
int
start_expert
,
const
int
end_expert
,
cudaStream_t
stream
)
{
static
constexpr
std
::
size_t
MAX_BYTES_PER_LDG
=
16
;
...
...
@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
token_expert_indices, num_tokens, topk, 0, num_experts, \
stream);
template
<
typename
IndType
>
void
topkGatingSoftmaxKernelLauncher
(
const
float
*
gating_output
,
float
*
topk_weights
,
int
*
topk_indicies
,
IndType
*
topk_indicies
,
int
*
token_expert_indices
,
float
*
softmax_workspace
,
const
int
num_tokens
,
...
...
@@ -493,14 +502,32 @@ void topk_softmax(
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
gating_output
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
torch
::
Tensor
softmax_workspace
=
torch
::
empty
({
workspace_size
},
gating_output
.
options
());
vllm
::
moe
::
topkGatingSoftmaxKernelLauncher
(
gating_output
.
data_ptr
<
float
>
(),
topk_weights
.
data_ptr
<
float
>
(),
topk_indices
.
data_ptr
<
int
>
(),
token_expert_indices
.
data_ptr
<
int
>
(),
softmax_workspace
.
data_ptr
<
float
>
(),
num_tokens
,
num_experts
,
topk
,
stream
);
if
(
topk_indices
.
scalar_type
()
==
at
::
ScalarType
::
Int
)
{
vllm
::
moe
::
topkGatingSoftmaxKernelLauncher
(
gating_output
.
data_ptr
<
float
>
(),
topk_weights
.
data_ptr
<
float
>
(),
topk_indices
.
data_ptr
<
int
>
(),
token_expert_indices
.
data_ptr
<
int
>
(),
softmax_workspace
.
data_ptr
<
float
>
(),
num_tokens
,
num_experts
,
topk
,
stream
);
}
else
{
assert
(
topk_indices
.
scalar_type
()
==
at
::
ScalarType
::
UInt32
);
vllm
::
moe
::
topkGatingSoftmaxKernelLauncher
(
gating_output
.
data_ptr
<
float
>
(),
topk_weights
.
data_ptr
<
float
>
(),
topk_indices
.
data_ptr
<
uint32_t
>
(),
token_expert_indices
.
data_ptr
<
int
>
(),
softmax_workspace
.
data_ptr
<
float
>
(),
num_tokens
,
num_experts
,
topk
,
stream
);
}
}
csrc/moe/torch_bindings.cpp
View file @
7a985548
...
...
@@ -44,7 +44,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m
.
def
(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? "
"b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
...
...
@@ -53,7 +54,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor"
);
m
.
def
(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"int b_q_type, SymInt size_m, "
"SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
"topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor"
);
m
.
def
(
"moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids,"
"Tensor token_expert_indicies, Tensor? expert_map, int n_expert,"
"int n_local_expert,"
"int topk, int? align_block_size,Tensor! permuted_input, Tensor! "
"expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! "
"m_indices)->()"
);
m
.
def
(
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor "
"expert_first_token_offset, int n_expert, int n_local_expert,int "
"topk, Tensor! hidden_states)->()"
);
// conditionally compiled so impl registration is in source file
#endif
...
...
csrc/ops.h
View file @
7a985548
...
...
@@ -59,6 +59,31 @@ void merge_attn_states(torch::Tensor& output,
const
torch
::
Tensor
&
prefix_lse
,
const
torch
::
Tensor
&
suffix_output
,
const
torch
::
Tensor
&
suffix_lse
);
void
convert_vertical_slash_indexes
(
torch
::
Tensor
&
block_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
block_offset
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch
::
Tensor
&
column_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
column_index
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch
::
Tensor
q_seqlens
,
// [BATCH, ]
torch
::
Tensor
kv_seqlens
,
// [BATCH, ]
torch
::
Tensor
vertical_indexes
,
// [BATCH, N_HEADS, NNZ_V]
torch
::
Tensor
slash_indexes
,
// [BATCH, N_HEADS, NNZ_S]
int64_t
context_size
,
int64_t
block_size_M
,
int64_t
block_size_N
,
bool
causal
);
void
convert_vertical_slash_indexes_mergehead
(
torch
::
Tensor
&
block_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
block_offset
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch
::
Tensor
&
column_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
column_index
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch
::
Tensor
q_seqlens
,
// [BATCH, ]
torch
::
Tensor
kv_seqlens
,
// [BATCH, ]
torch
::
Tensor
vertical_indexes
,
// [BATCH, N_HEADS, NNZ_V]
torch
::
Tensor
slash_indexes
,
// [BATCH, N_HEADS, NNZ_S]
torch
::
Tensor
vertical_indices_count
,
// [N_HEADS, ]
torch
::
Tensor
slash_indices_count
,
int64_t
context_size
,
int64_t
block_size_M
,
int64_t
block_size_N
,
bool
causal
);
#endif
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
...
...
@@ -86,17 +111,20 @@ void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
// std::optional<torch::Tensor> residual);
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int64_t
head_size
,
std
::
optional
<
torch
::
Tensor
>
key
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
,
int64_t
rot_dim
,
std
::
optional
<
torch
::
Tensor
>
key
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
,
int64_t
rot_dim
,
torch
::
Tensor
&
cos_sin_cache_offsets
);
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
silu_and_mul_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
scale
);
void
mul_and_silu
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
...
...
@@ -177,6 +205,10 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
torch
::
Tensor
num_tokens_post_padded
,
int64_t
type
,
int64_t
row
,
int64_t
top_k
,
int64_t
tokens
);
torch
::
Tensor
ggml_moe_a8_vec
(
torch
::
Tensor
X
,
torch
::
Tensor
W
,
torch
::
Tensor
topk_ids
,
int64_t
top_k
,
int64_t
type
,
int64_t
row
,
int64_t
tokens
);
int64_t
ggml_moe_get_block_size
(
int64_t
type
);
#ifndef USE_ROCM
...
...
@@ -203,6 +235,12 @@ void cutlass_moe_mm(
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
);
void
cutlass_fp4_group_mm
(
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
a_blockscale
,
const
torch
::
Tensor
&
b_blockscales
,
const
torch
::
Tensor
&
alphas
,
const
torch
::
Tensor
&
problem_sizes
,
const
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
sf_offsets
);
void
get_cutlass_moe_mm_data
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
...
...
@@ -230,6 +268,12 @@ std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input_scale
);
void
scaled_fp4_experts_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_global_scale
,
torch
::
Tensor
const
&
input_offset_by_experts
,
torch
::
Tensor
const
&
output_scale_offset_by_experts
);
#endif
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
...
...
csrc/pos_encoding_kernels.cu
View file @
7a985548
...
...
@@ -38,12 +38,14 @@ inline __device__ void apply_rotary_embedding(
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
scalar_t
*
__restrict__
key
,
// nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const
scalar_t
*
cache_ptr
,
const
int
head_size
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
rot_dim
,
const
int
token_idx
,
const
int64_t
query_stride
,
const
int64_t
key_stride
)
{
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int64_t
head_stride
)
{
const
int
embed_dim
=
rot_dim
/
2
;
const
scalar_t
*
cos_ptr
=
cache_ptr
;
const
scalar_t
*
sin_ptr
=
cache_ptr
+
embed_dim
;
...
...
@@ -51,19 +53,23 @@ inline __device__ void apply_rotary_embedding(
const
int
nq
=
num_heads
*
embed_dim
;
for
(
int
i
=
threadIdx
.
x
;
i
<
nq
;
i
+=
blockDim
.
x
)
{
const
int
head_idx
=
i
/
embed_dim
;
const
int64_t
token_head
=
token_idx
*
query_stride
+
head_idx
*
head_size
;
const
int64_t
token_head
=
token_idx
*
query_stride
+
head_idx
*
head_stride
;
const
int
rot_offset
=
i
%
embed_dim
;
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
}
const
int
nk
=
num_kv_heads
*
embed_dim
;
for
(
int
i
=
threadIdx
.
x
;
i
<
nk
;
i
+=
blockDim
.
x
)
{
const
int
head_idx
=
i
/
embed_dim
;
const
int64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_size
;
const
int
rot_offset
=
i
%
embed_dim
;
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
key
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
if
(
key
!=
nullptr
)
{
const
int
nk
=
num_kv_heads
*
embed_dim
;
for
(
int
i
=
threadIdx
.
x
;
i
<
nk
;
i
+=
blockDim
.
x
)
{
const
int
head_idx
=
i
/
embed_dim
;
const
int64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_stride
;
const
int
rot_offset
=
i
%
embed_dim
;
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
key
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
}
}
}
...
...
@@ -74,13 +80,15 @@ __global__ void rotary_embedding_kernel(
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
scalar_t
*
__restrict__
key
,
// nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
// 2]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
const
int64_t
head_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
// Each thread block is responsible for one token.
const
int
token_idx
=
blockIdx
.
x
;
int64_t
pos
=
positions
[
token_idx
];
...
...
@@ -88,7 +96,7 @@ __global__ void rotary_embedding_kernel(
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
token_idx
,
query_stride
,
key_stride
,
head_stride
);
}
template
<
typename
scalar_t
,
bool
IS_NEOX
>
...
...
@@ -98,15 +106,16 @@ __global__ void batched_rotary_embedding_kernel(
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
scalar_t
*
__restrict__
key
,
// nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
// 2]
const
int64_t
*
__restrict__
cos_sin_cache_offsets
,
// [batch_size, seq_len]
// or [num_tokens]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
const
int64_t
head_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
// Each thread block is responsible for one token.
const
int
token_idx
=
blockIdx
.
x
;
int64_t
pos
=
positions
[
token_idx
];
...
...
@@ -116,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel(
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
token_idx
,
query_stride
,
key_stride
,
head_stride
);
}
}
// namespace vllm
...
...
@@ -127,10 +136,12 @@ void rotary_embedding(
// [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
std
::
optional
<
torch
::
Tensor
>
key
,
// null or
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
)
{
...
...
@@ -138,40 +149,46 @@ void rotary_embedding(
int64_t
num_tokens
=
positions
.
numel
();
int
positions_ndim
=
positions
.
dim
();
// Make sure num_tokens dim is consistent across positions, query, and key
.
// Make sure num_tokens dim is consistent across positions, query, and key
TORCH_CHECK
(
positions_ndim
==
1
||
positions_ndim
==
2
,
"positions must have shape [num_tokens] or [batch_size, seq_len]"
);
if
(
positions_ndim
==
1
)
{
TORCH_CHECK
(
query
.
siz
e
(
0
)
==
positions
.
size
(
0
)
&&
key
.
size
(
0
)
==
positions
.
size
(
0
),
"query, key and positions must have the same number of tokens"
);
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
(
!
key
.
has_valu
e
()
||
key
->
size
(
0
)
==
positions
.
size
(
0
)
)
,
"query, key and positions must have the same number of tokens"
);
}
if
(
positions_ndim
==
2
)
{
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
key
.
size
(
0
)
==
positions
.
size
(
0
)
&&
(
!
key
.
has_value
()
||
key
->
size
(
0
)
==
positions
.
size
(
0
)
)
&&
query
.
size
(
1
)
==
positions
.
size
(
1
)
&&
key
.
size
(
1
)
==
positions
.
size
(
1
),
(
!
key
.
has_value
()
||
key
->
size
(
1
)
==
positions
.
size
(
1
)
)
,
"query, key and positions must have the same batch_size and seq_len"
);
}
// Make sure head_size is valid for query and key
// hidden_size = num_heads * head_size
int
query_hidden_size
=
query
.
numel
()
/
num_tokens
;
int
key_hidden_size
=
key
.
numel
()
/
num_tokens
;
int
key_hidden_size
=
key
.
has_value
()
?
key
->
numel
()
/
num_tokens
:
0
;
TORCH_CHECK
(
query_hidden_size
%
head_size
==
0
);
TORCH_CHECK
(
key_hidden_size
%
head_size
==
0
);
// Make sure query and key have consistent number of heads
int
num_heads
=
query_hidden_size
/
head_size
;
int
num_kv_heads
=
key_hidden_size
/
head_size
;
int
num_kv_heads
=
key
.
has_value
()
?
key_hidden_size
/
head_size
:
num_heads
;
TORCH_CHECK
(
num_heads
%
num_kv_heads
==
0
);
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
seq_dim_idx
=
positions_ndim
-
1
;
int64_t
query_stride
=
query
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
has_value
()
?
key
->
stride
(
seq_dim_idx
)
:
0
;
// Determine head stride: for [*, heads, head_size] use stride of last dim;
// for flat [*, heads*head_size], heads blocks are contiguous of size
// head_size
int
query_ndim
=
query
.
dim
();
int64_t
head_stride
=
(
query_ndim
==
positions_ndim
+
2
)
?
query
.
stride
(
-
2
)
:
head_size
;
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
...
...
@@ -181,15 +198,16 @@ void rotary_embedding(
if
(
is_neox
)
{
vllm
::
rotary_embedding_kernel
<
scalar_t
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
head_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
else
{
vllm
::
rotary_embedding_kernel
<
scalar_t
,
false
>
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
head_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
});
}
...
...
@@ -204,10 +222,12 @@ void batched_rotary_embedding(
// [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
std
::
optional
<
torch
::
Tensor
>
key
,
// null or
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
,
int64_t
rot_dim
,
...
...
@@ -221,38 +241,44 @@ void batched_rotary_embedding(
"cos_sin_cache_offsets"
);
int
positions_ndim
=
positions
.
dim
();
// Make sure num_tokens dim is consistent across positions, query, and key
.
// Make sure num_tokens dim is consistent across positions, query, and key
TORCH_CHECK
(
positions_ndim
==
1
||
positions_ndim
==
2
,
"positions must have shape [num_tokens] or [batch_size, seq_len]"
);
if
(
positions_ndim
==
1
)
{
TORCH_CHECK
(
query
.
siz
e
(
0
)
==
positions
.
size
(
0
)
&&
key
.
size
(
0
)
==
positions
.
size
(
0
),
"query, key and positions must have the same number of tokens"
);
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
(
!
key
.
has_valu
e
()
||
key
->
size
(
0
)
==
positions
.
size
(
0
)
)
,
"query, key and positions must have the same number of tokens"
);
}
if
(
positions_ndim
==
2
)
{
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
key
.
size
(
0
)
==
positions
.
size
(
0
)
&&
(
!
key
.
has_value
()
||
key
->
size
(
0
)
==
positions
.
size
(
0
)
)
&&
query
.
size
(
1
)
==
positions
.
size
(
1
)
&&
key
.
size
(
1
)
==
positions
.
size
(
1
),
(
!
key
.
has_value
()
||
key
->
size
(
1
)
==
positions
.
size
(
1
)
)
,
"query, key and positions must have the same batch_size and seq_len"
);
}
// Make sure head_size is valid for query and key
int
query_hidden_size
=
query
.
numel
()
/
num_tokens
;
int
key_hidden_size
=
key
.
numel
()
/
num_tokens
;
int
key_hidden_size
=
key
.
has_value
()
?
key
->
numel
()
/
num_tokens
:
0
;
TORCH_CHECK
(
query_hidden_size
%
head_size
==
0
);
TORCH_CHECK
(
key_hidden_size
%
head_size
==
0
);
// Make sure query and key have concistent number of heads
int
num_heads
=
query_hidden_size
/
head_size
;
int
num_kv_heads
=
key_hidden_size
/
head_size
;
int
num_kv_heads
=
key
.
has_value
()
?
key_hidden_size
/
head_size
:
num_heads
;
TORCH_CHECK
(
num_heads
%
num_kv_heads
==
0
);
int
seq_dim_idx
=
positions_ndim
-
1
;
int64_t
query_stride
=
query
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
has_value
()
?
key
->
stride
(
seq_dim_idx
)
:
0
;
// Determine head stride: for [*, heads, head_size] use stride of last dim;
// for flat [*, heads*head_size], heads blocks are contiguous of size
// head_size
int
query_ndim
=
query
.
dim
();
int64_t
head_stride
=
(
query_ndim
==
positions_ndim
+
2
)
?
query
.
stride
(
-
2
)
:
head_size
;
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
...
...
@@ -263,16 +289,18 @@ void batched_rotary_embedding(
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
true
>
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
key_stride
,
head_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
else
{
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
false
>
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
key_stride
,
head_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
});
}
csrc/quantization/activation_kernels.cu
0 → 100644
View file @
7a985548
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <cmath>
#include "core/math.hpp"
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"
namespace
vllm
{
template
<
typename
T
>
__device__
__forceinline__
T
silu_kernel
(
const
T
&
x
)
{
// x * sigmoid(x)
return
(
T
)(((
float
)
x
)
/
(
1.0
f
+
expf
((
float
)
-
x
)));
}
// Activation and gating kernel template.
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
),
typename
fp8_type
>
__global__
void
act_and_mul_quant_kernel
(
fp8_type
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., 2, d]
const
float
*
scale
,
const
int
d
)
{
const
int32_t
blocks_per_token
=
gridDim
.
y
;
const
int32_t
elems_per_128bit_load
=
(
128
/
8
)
/
sizeof
(
scalar_t
);
// We don't expect the hidden dimension to exceed 32 bits so int32 should
// be safe here.
const
int32_t
tgt_elems_per_block
=
div_ceil
(
d
,
blocks_per_token
);
const
int32_t
elems_per_block
=
round_to_next_multiple_of
(
tgt_elems_per_block
,
elems_per_128bit_load
);
const
int32_t
block_start
=
blockIdx
.
y
*
elems_per_block
;
int32_t
block_end
=
block_start
+
elems_per_block
;
block_end
=
block_end
>
d
?
d
:
block_end
;
// token_idx is 64 bit to prevent 32 bit overflow when the number of tokens
// is very large
const
int64_t
token_idx
=
blockIdx
.
x
;
const
scalar_t
*
__restrict__
x_ptr
=
input
+
token_idx
*
2
*
d
;
const
scalar_t
*
__restrict__
y_ptr
=
input
+
token_idx
*
2
*
d
+
d
;
fp8_type
*
__restrict__
out_ptr
=
out
+
token_idx
*
d
;
// 128-bit vectorized code
const
int32_t
vec_loop_end
=
round_to_previous_multiple_of
(
elems_per_128bit_load
,
block_end
);
const
int32_t
vec_end_idx
=
vec_loop_end
/
elems_per_128bit_load
;
const
int32_t
vec_start_idx
=
block_start
/
elems_per_128bit_load
;
const
int4
*
__restrict__
x_128bit_ptr
=
reinterpret_cast
<
const
int4
*>
(
x_ptr
);
const
int4
*
__restrict__
y_128bit_ptr
=
reinterpret_cast
<
const
int4
*>
(
y_ptr
);
int2
*
__restrict__
out_128bit_ptr
=
reinterpret_cast
<
int2
*>
(
out_ptr
);
float
inverted_scale
=
1
/
*
scale
;
#pragma unroll
for
(
int32_t
vec_idx
=
vec_start_idx
+
threadIdx
.
x
;
vec_idx
<
vec_end_idx
;
vec_idx
+=
blockDim
.
x
)
{
const
int4
x_128bit
=
VLLM_LDG
(
&
x_128bit_ptr
[
vec_idx
]);
const
int4
y_128bit
=
VLLM_LDG
(
&
y_128bit_ptr
[
vec_idx
]);
using
scalar_128bit_vec_t
=
std
::
array
<
scalar_t
,
elems_per_128bit_load
>
;
using
scalar_64bit_vec_t
=
std
::
array
<
fp8_type
,
elems_per_128bit_load
>
;
scalar_64bit_vec_t
out_vec
;
const
auto
x_vec
=
reinterpret_cast
<
scalar_128bit_vec_t
const
&>
(
x_128bit
);
const
auto
y_vec
=
reinterpret_cast
<
scalar_128bit_vec_t
const
&>
(
y_128bit
);
#pragma unroll
for
(
int
i
=
0
;
i
<
elems_per_128bit_load
;
i
++
)
{
out_vec
[
i
]
=
scaled_fp8_conversion
<
true
,
fp8_type
>
(
ACT_FN
(
x_vec
[
i
])
*
y_vec
[
i
],
inverted_scale
);
}
out_128bit_ptr
[
vec_idx
]
=
reinterpret_cast
<
const
int2
&>
(
out_vec
);
}
// Scalar cleanup code
if
(
block_end
>
vec_loop_end
)
{
for
(
int64_t
idx
=
vec_loop_end
+
threadIdx
.
x
;
idx
<
block_end
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
x
=
VLLM_LDG
(
&
x_ptr
[
idx
]);
const
scalar_t
y
=
VLLM_LDG
(
&
y_ptr
[
idx
]);
out_ptr
[
idx
]
=
scaled_fp8_conversion
<
true
,
fp8_type
>
(
ACT_FN
(
x
)
*
y
,
inverted_scale
);
}
}
}
}
// namespace vllm
// Launch activation, gating, and quantize kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4); \
dim3 block(std::min(d, 512)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \
VLLM_DISPATCH_FP8_TYPES( \
out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \
vllm::act_and_mul_quant_kernel<scalar_t, KERNEL<scalar_t>, \
fp8_t> \
<<<grid, block, 0, stream>>>(out.data_ptr<fp8_t>(), \
input.data_ptr<scalar_t>(), \
scale.data_ptr<float>(), d); \
}); \
});
void
silu_and_mul_quant
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
,
// [..., 2 * d]
torch
::
Tensor
&
scale
)
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat8_e4m3fn
||
out
.
dtype
()
==
torch
::
kFloat8_e4m3fnuz
);
TORCH_CHECK
(
input
.
dtype
()
==
torch
::
kFloat16
||
input
.
dtype
()
==
torch
::
kBFloat16
);
TORCH_CHECK
(
input
.
size
(
-
1
)
%
2
==
0
);
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
silu_kernel
);
}
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
View file @
7a985548
...
...
@@ -26,7 +26,13 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
float
dst
=
std
::
nearbyint
(
x
);
// saturate
dst
=
std
::
clamp
(
dst
,
i8_min
,
i8_max
);
// See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
// Arch/gcc14. The following replaces std::clamp usage with similar logic
// dst = std::clamp(dst, i8_min, i8_max);
dst
=
(
dst
<
i8_min
)
?
i8_min
:
(
dst
>
i8_max
)
?
i8_max
:
dst
;
return
static_cast
<
int8_t
>
(
dst
);
#else
// CUDA path
...
...
@@ -79,7 +85,13 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
static_cast
<
int32_t
>
(
std
::
numeric_limits
<
int8_t
>::
max
());
// saturate
int32_t
dst
=
std
::
clamp
(
x
,
i8_min
,
i8_max
);
// See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
// Arch/gcc14. The following replaces std::clamp usage with similar logic
// int32_t dst = std::clamp(x, i8_min, i8_max);
int32_t
dst
=
(
x
<
i8_min
)
?
i8_min
:
(
x
>
i8_max
)
?
i8_max
:
x
;
return
static_cast
<
int8_t
>
(
dst
);
#else
// CUDA path
...
...
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
0 → 100644
View file @
7a985548
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
void
cutlass_scaled_mm_blockwise_sm100_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
TORCH_CHECK
(
a
.
size
(
0
)
%
4
==
0
,
"Input tensor must have a number of rows that is a multiple of 4. "
,
"but got: "
,
a
.
size
(
0
),
" rows."
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
cutlass_gemm_blockwise_sm100_fp8_dispatch
<
cutlass
::
bfloat16_t
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
cutlass_gemm_blockwise_sm100_fp8_dispatch
<
cutlass
::
half_t
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
0 → 100644
View file @
7a985548
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh"
namespace
vllm
{
using
namespace
cute
;
template
<
typename
OutType
,
typename
MmaTileShape
,
typename
ScalesPerTile
,
class
ClusterShape
,
typename
EpilogueScheduler
,
typename
MainloopScheduler
>
struct
cutlass_3x_gemm_fp8_blockwise
{
using
ElementAB
=
cutlass
::
float_e4m3_t
;
using
ElementA
=
ElementAB
;
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentA
=
128
/
cutlass
::
sizeof_bits
<
ElementA
>::
value
;
using
ElementB
=
ElementAB
;
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
static
constexpr
int
AlignmentB
=
128
/
cutlass
::
sizeof_bits
<
ElementB
>::
value
;
using
ElementC
=
void
;
using
ElementD
=
OutType
;
using
LayoutD
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
using
LayoutC
=
LayoutD
;
static
constexpr
int
AlignmentC
=
AlignmentD
;
using
ElementAccumulator
=
float
;
using
ElementCompute
=
float
;
using
ElementBlockScale
=
float
;
// MMA and Cluster Tile Shapes
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster
// Shape %2 == 0 using MmaTileShape_MNK = Shape<_128,_128,_128>;
static
constexpr
int
ScaleMsPerTile
=
size
<
0
>
(
ScalesPerTile
{});
static
constexpr
int
ScaleGranularityM
=
size
<
0
>
(
MmaTileShape
{})
/
ScaleMsPerTile
;
static
constexpr
int
ScaleGranularityN
=
size
<
1
>
(
MmaTileShape
{})
/
size
<
1
>
(
ScalesPerTile
{});
static
constexpr
int
ScaleGranularityK
=
size
<
2
>
(
MmaTileShape
{})
/
size
<
2
>
(
ScalesPerTile
{});
// Shape of the threadblocks in a cluster
using
ClusterShape_MNK
=
ClusterShape
;
using
ScaleConfig
=
cutlass
::
detail
::
Sm100BlockwiseScaleConfig
<
ScaleGranularityM
,
ScaleGranularityN
,
ScaleGranularityK
,
cute
::
UMMA
::
Major
::
MN
,
cute
::
UMMA
::
Major
::
K
>
;
using
LayoutSFA
=
decltype
(
ScaleConfig
::
deduce_layoutSFA
());
using
LayoutSFB
=
decltype
(
ScaleConfig
::
deduce_layoutSFB
());
using
ArchTag
=
cutlass
::
arch
::
Sm100
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
static
constexpr
auto
RoundStyle
=
cutlass
::
FloatRoundStyle
::
round_to_nearest
;
using
ElementScalar
=
float
;
// clang-format off
using
DefaultOperation
=
cutlass
::
epilogue
::
fusion
::
LinearCombination
<
ElementD
,
ElementCompute
,
ElementC
,
ElementScalar
,
RoundStyle
>
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
MmaTileShape
,
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementAccumulator
,
ElementCompute
,
ElementC
,
LayoutC
,
AlignmentC
,
ElementD
,
LayoutD
,
AlignmentD
,
EpilogueScheduler
,
DefaultOperation
>::
CollectiveOp
;
using
StageCountType
=
cutlass
::
gemm
::
collective
::
StageCountAuto
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
ElementA
,
cute
::
tuple
<
LayoutA
,
LayoutSFA
>
,
AlignmentA
,
ElementB
,
cute
::
tuple
<
LayoutB
,
LayoutSFB
>
,
AlignmentB
,
ElementAccumulator
,
MmaTileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
MainloopScheduler
>::
CollectiveOp
;
// clang-format on
using
KernelType
=
enable_sm100_only
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
>>
;
struct
GemmKernel
:
public
KernelType
{};
};
template
<
typename
Gemm
>
void
cutlass_gemm_caller_blockwise
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
using
StrideA
=
typename
Gemm
::
GemmKernel
::
StrideA
;
using
StrideB
=
typename
Gemm
::
GemmKernel
::
StrideB
;
using
StrideD
=
typename
Gemm
::
GemmKernel
::
StrideD
;
using
StrideC
=
typename
Gemm
::
GemmKernel
::
StrideC
;
using
LayoutSFA
=
typename
Gemm
::
LayoutSFA
;
using
LayoutSFB
=
typename
Gemm
::
LayoutSFB
;
using
ScaleConfig
=
typename
Gemm
::
ScaleConfig
;
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
int32_t
m
=
a
.
size
(
0
),
n
=
b
.
size
(
1
),
k
=
a
.
size
(
1
);
auto
prob_shape
=
cute
::
make_shape
(
m
,
n
,
k
,
1
);
StrideA
a_stride
;
StrideB
b_stride
;
StrideC
c_stride
;
a_stride
=
cutlass
::
make_cute_packed_stride
(
StrideA
{},
cute
::
make_shape
(
m
,
k
,
1
));
b_stride
=
cutlass
::
make_cute_packed_stride
(
StrideB
{},
cute
::
make_shape
(
n
,
k
,
1
));
c_stride
=
cutlass
::
make_cute_packed_stride
(
StrideC
{},
cute
::
make_shape
(
m
,
n
,
1
));
LayoutSFA
layout_SFA
=
ScaleConfig
::
tile_atom_to_shape_SFA
(
make_shape
(
m
,
n
,
k
,
1
));
LayoutSFB
layout_SFB
=
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
m
,
n
,
k
,
1
));
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
auto
a_scales_ptr
=
static_cast
<
float
*>
(
a_scales
.
data_ptr
());
auto
b_scales_ptr
=
static_cast
<
float
*>
(
b_scales
.
data_ptr
());
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
a_stride
,
b_ptr
,
b_stride
,
a_scales_ptr
,
layout_SFA
,
b_scales_ptr
,
layout_SFB
};
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
{},
c_ptr
,
c_stride
,
c_ptr
,
c_stride
};
c3x
::
cutlass_gemm_caller
<
GemmKernel
>
(
a
.
device
(),
prob_shape
,
mainloop_args
,
epilogue_args
);
}
template
<
typename
OutType
>
void
cutlass_gemm_blockwise_sm100_fp8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
auto
m
=
a
.
size
(
0
);
auto
k
=
a
.
size
(
1
);
auto
n
=
b
.
size
(
1
);
int
sms
;
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
a
.
get_device
());
auto
should_use_2sm
=
[
&
sms
](
int
m
,
int
n
,
int
tile1SM
=
128
)
{
return
std
::
ceil
(
static_cast
<
float
>
(
m
)
/
tile1SM
)
*
std
::
ceil
(
static_cast
<
float
>
(
n
)
/
tile1SM
)
>=
sms
;
};
bool
use_2sm
=
should_use_2sm
(
m
,
n
);
if
(
use_2sm
)
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
Shape
<
_256
,
_128
,
_128
>
,
Shape
<
_256
,
_1
,
_1
>
,
Shape
<
_2
,
_2
,
_1
>
,
cutlass
::
epilogue
::
TmaWarpSpecialized2Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwise2SmSm100
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
Shape
<
_128
,
_128
,
_128
>
,
Shape
<
_128
,
_1
,
_1
>
,
Shape
<
_1
,
_1
,
_1
>
,
cutlass
::
epilogue
::
TmaWarpSpecialized1Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwise1SmSm100
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
0 → 100644
View file @
7a985548
#include <torch/all.h>
#include "cuda_utils.h"
#include "cutlass_extensions/common.hpp"
template
<
typename
Fp8Func
,
typename
Int8Func
,
typename
BlockwiseFunc
>
void
dispatch_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
,
Fp8Func
fp8_func
,
Int8Func
int8_func
,
BlockwiseFunc
blockwise_func
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
int
M
=
a
.
size
(
0
),
N
=
b
.
size
(
1
),
K
=
a
.
size
(
1
);
if
((
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
))
&&
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
)))
{
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
)
{
fp8_func
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
if
constexpr
(
!
std
::
is_same_v
<
Int8Func
,
std
::
nullptr_t
>
)
{
int8_func
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
TORCH_CHECK
(
false
,
"Int8 not supported for this architecture"
);
}
}
}
else
{
TORCH_CHECK
(
a_scales
.
dim
()
==
2
,
"a scale must be 2d tensor."
);
TORCH_CHECK
(
b_scales
.
dim
()
==
2
,
"b scale must be 2d tensor."
);
int32_t
version_num
=
get_sm_version_num
();
if
(
version_num
>=
100
)
{
TORCH_CHECK
(
a
.
size
(
0
)
==
a_scales
.
size
(
0
)
&&
cuda_utils
::
ceil_div
(
a
.
size
(
1
),
int64_t
(
128
))
==
a_scales
.
size
(
1
),
"a_scale_group_shape must be [1, 128]."
);
TORCH_CHECK
(
cuda_utils
::
ceil_div
(
b
.
size
(
0
),
int64_t
(
128
))
==
b_scales
.
size
(
0
)
&&
cuda_utils
::
ceil_div
(
b
.
size
(
1
),
int64_t
(
128
))
==
b_scales
.
size
(
1
),
"b_scale_group_shape must be [128, 128]."
);
}
else
{
// TODO: Remove this after using cutlass sm90 blockwise scaling gemm
// kernel, or introducing ceil_div to the load_init() of mainloop.
using
GroupShape
=
std
::
array
<
int64_t
,
2
>
;
auto
make_group_shape
=
[](
torch
::
Tensor
const
&
x
,
torch
::
Tensor
const
&
s
)
->
GroupShape
{
TORCH_CHECK
(
s
.
dim
()
==
2
,
"cutlass_scaled_mm group scales must be 2D"
);
return
{
cuda_utils
::
ceil_div
(
x
.
size
(
0
),
s
.
size
(
0
)),
cuda_utils
::
ceil_div
(
x
.
size
(
1
),
s
.
size
(
1
))};
};
GroupShape
a_scale_group_shape
=
make_group_shape
(
a
,
a_scales
);
GroupShape
b_scale_group_shape
=
make_group_shape
(
b
,
b_scales
);
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK
((
a_scale_group_shape
==
GroupShape
{
1
,
128
}
&&
b_scale_group_shape
==
GroupShape
{
128
,
128
}
&&
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
&&
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
),
"cutlass_scaled_mm only supports datatype float8_e4m3fn.
\n
"
"a_scale_group_shape must be [1, 128]. Got: ["
,
a_scale_group_shape
[
0
],
", "
,
a_scale_group_shape
[
1
],
"]
\n
"
"b_scale_group_shape must be [128, 128]. Got: ["
,
b_scale_group_shape
[
0
],
", "
,
b_scale_group_shape
[
1
],
"]"
);
}
TORCH_CHECK
(
!
bias
,
"Bias not yet supported blockwise scaled_mm"
);
blockwise_func
(
c
,
a
,
b
,
a_scales
,
b_scales
);
}
}
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
View file @
7a985548
...
...
@@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_blockwise_sm100_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
View file @
7a985548
#include
<cudaTypedefs.h>
#include
"c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"
#include "cuda_utils.h"
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm100 (Blackwell).
...
...
@@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
int
M
=
a
.
size
(
0
),
N
=
b
.
size
(
1
),
K
=
a
.
size
(
1
);
TORCH_CHECK
(
(
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
))
&&
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
)),
"Currently, block scaled fp8 gemm is not implemented for Blackwell"
);
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
,
"Currently, only fp8 gemm is implemented for Blackwell"
);
vllm
::
cutlass_scaled_mm_sm100_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
dispatch_scaled_mm
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
,
vllm
::
cutlass_scaled_mm_sm100_fp8
,
nullptr
,
// int8 not supported on SM100
vllm
::
cutlass_scaled_mm_blockwise_sm100_fp8
);
}
#endif
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
View file @
7a985548
#include
<cudaTypedefs.h>
#include
"c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"
#include "cuda_utils.h"
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper).
...
...
@@ -15,49 +13,10 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
int
M
=
a
.
size
(
0
),
N
=
b
.
size
(
1
),
K
=
a
.
size
(
1
);
if
((
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
))
&&
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
)))
{
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
)
{
vllm
::
cutlass_scaled_mm_sm90_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
vllm
::
cutlass_scaled_mm_sm90_int8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
}
else
{
using
GroupShape
=
std
::
array
<
int64_t
,
2
>
;
auto
make_group_shape
=
[](
torch
::
Tensor
const
&
x
,
torch
::
Tensor
const
&
s
)
->
GroupShape
{
TORCH_CHECK
(
s
.
dim
()
==
2
,
"cutlass_scaled_mm group scales must be 2D"
);
return
{
cuda_utils
::
ceil_div
(
x
.
size
(
0
),
s
.
size
(
0
)),
cuda_utils
::
ceil_div
(
x
.
size
(
1
),
s
.
size
(
1
))};
};
GroupShape
a_scale_group_shape
=
make_group_shape
(
a
,
a_scales
);
GroupShape
b_scale_group_shape
=
make_group_shape
(
b
,
b_scales
);
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK
((
a_scale_group_shape
==
GroupShape
{
1
,
128
}
&&
b_scale_group_shape
==
GroupShape
{
128
,
128
}
&&
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
&&
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
),
"cutlass_scaled_mm only supports datatype float8_e4m3fn.
\n
"
"a_scale_group_shape must be [1, 128]. Got: ["
,
a_scale_group_shape
[
0
],
", "
,
a_scale_group_shape
[
1
],
"]
\n
"
"b_scale_group_shape must be [128, 128]. Got: ["
,
b_scale_group_shape
[
0
],
", "
,
b_scale_group_shape
[
1
],
"]"
);
TORCH_CHECK
(
!
bias
,
"Bias not yet supported blockwise scaled_mm"
);
vllm
::
cutlass_scaled_mm_blockwise_sm90_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
);
}
dispatch_scaled_mm
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
,
vllm
::
cutlass_scaled_mm_sm90_fp8
,
vllm
::
cutlass_scaled_mm_sm90_int8
,
vllm
::
cutlass_scaled_mm_blockwise_sm90_fp8
);
}
void
cutlass_scaled_mm_azp_sm90
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
7a985548
...
...
@@ -29,7 +29,8 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
void
cutlass_moe_mm_sm90
(
torch
::
Tensor
&
out_tensors
,
torch
::
Tensor
const
&
a_tensors
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
a_scales
,
...
...
@@ -37,12 +38,6 @@ void cutlass_moe_mm_sm90(
torch
::
Tensor
const
&
problem_sizes
,
torch
::
Tensor
const
&
a_strides
,
torch
::
Tensor
const
&
b_strides
,
torch
::
Tensor
const
&
c_strides
);
void
get_cutlass_moe_mm_data_caller
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
);
#endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
...
...
@@ -53,6 +48,15 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \
defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100
void
get_cutlass_moe_mm_data_caller
(
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
expert_offsets
,
torch
::
Tensor
&
problem_sizes1
,
torch
::
Tensor
&
problem_sizes2
,
torch
::
Tensor
&
input_permutation
,
torch
::
Tensor
&
output_permutation
,
const
int64_t
num_experts
,
const
int64_t
n
,
const
int64_t
k
);
#endif
void
cutlass_scaled_mm_azp_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
...
...
@@ -110,6 +114,8 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
if
(
cuda_device_capability
>=
90
&&
cuda_device_capability
<
100
)
{
return
CUDA_VERSION
>=
12000
;
}
else
if
(
cuda_device_capability
>=
100
)
{
return
CUDA_VERSION
>=
12080
;
}
#endif
...
...
@@ -222,7 +228,8 @@ void get_cutlass_moe_mm_data(
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t
version_num
=
get_sm_version_num
();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
get_cutlass_moe_mm_data_caller
(
topk_ids
,
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
input_permutation
,
output_permutation
,
num_experts
,
n
,
k
);
...
...
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
0 → 100644
View file @
7a985548
#include <torch/all.h>
#include <cutlass/arch/arch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include <cassert>
using
namespace
cute
;
template
<
typename
ElementAB
,
typename
ElementC
,
typename
ElementSF
,
typename
ElementAccumulator
,
typename
LayoutSFA
,
typename
LayoutSFB
,
typename
ScaleConfig
>
__global__
void
__get_group_gemm_starts
(
ElementAB
**
a_offsets
,
ElementAB
**
b_offsets
,
ElementC
**
out_offsets
,
ElementSF
**
a_scales_offsets
,
ElementSF
**
b_scales_offsets
,
ElementAccumulator
**
alpha_offsets
,
LayoutSFA
*
layout_sfa_base_as_int
,
LayoutSFB
*
layout_sfb_base_as_int
,
ElementAB
*
a_base_as_int
,
ElementAB
*
b_base_as_int
,
ElementC
*
out_base_as_int
,
ElementSF
*
a_scales_base_as_int
,
ElementSF
*
b_scales_base_as_int
,
ElementAccumulator
*
alphas_base_as_int
,
const
int32_t
*
expert_offsets
,
const
int32_t
*
sf_offsets
,
const
int32_t
*
problem_sizes_as_shapes
,
const
int
K
,
const
int
N
)
{
int64_t
expert_id
=
threadIdx
.
x
;
if
(
expert_id
>=
gridDim
.
x
*
blockDim
.
x
)
{
return
;
}
// Originally int32_t but upcasting to int64_t to avoid overflow
// during offset calculations
int64_t
expert_offset
=
static_cast
<
int64_t
>
(
expert_offsets
[
expert_id
]);
int64_t
sf_offset
=
static_cast
<
int64_t
>
(
sf_offsets
[
expert_id
]);
// size for block in block scale.
int64_t
group_size
=
16
;
int64_t
m
=
static_cast
<
int64_t
>
(
problem_sizes_as_shapes
[
expert_id
*
3
]);
int64_t
n
=
static_cast
<
int64_t
>
(
problem_sizes_as_shapes
[
expert_id
*
3
+
1
]);
int64_t
k
=
static_cast
<
int64_t
>
(
problem_sizes_as_shapes
[
expert_id
*
3
+
2
]);
assert
((
m
>=
0
&&
n
==
N
&&
k
==
K
&&
k
%
2
==
0
)
&&
"unexpected problem sizes"
);
int64_t
half_k
=
static_cast
<
int64_t
>
(
k
/
2
);
int64_t
group_k
=
static_cast
<
int64_t
>
(
k
/
group_size
);
// Shape of A as uint8/byte = [M, K // 2]
// Shape of B as uint8/byte = [E, N, K // 2]
a_offsets
[
expert_id
]
=
a_base_as_int
+
expert_offset
*
half_k
;
b_offsets
[
expert_id
]
=
b_base_as_int
+
expert_id
*
n
*
half_k
;
// Shape of C = [M, N]
out_offsets
[
expert_id
]
=
out_base_as_int
+
expert_offset
*
n
;
// Shape of a_scale = [sum(sf_sizes), K // group_size]
a_scales_offsets
[
expert_id
]
=
a_scales_base_as_int
+
sf_offset
*
group_k
;
assert
((
reinterpret_cast
<
uintptr_t
>
(
a_scales_offsets
[
expert_id
])
%
128
)
==
0
&&
"TMA requires 128-byte alignment"
);
// Shape of B scale = [E, N, K // group_size]
b_scales_offsets
[
expert_id
]
=
b_scales_base_as_int
+
expert_id
*
n
*
group_k
;
assert
((
reinterpret_cast
<
uintptr_t
>
(
b_scales_offsets
[
expert_id
])
%
128
)
==
0
&&
"TMA requires 128-byte alignment"
);
// Shape of alpha = [E]
alpha_offsets
[
expert_id
]
=
alphas_base_as_int
+
expert_id
;
LayoutSFA
*
layout_sfa_ptr
=
layout_sfa_base_as_int
+
expert_id
;
LayoutSFB
*
layout_sfb_ptr
=
layout_sfb_base_as_int
+
expert_id
;
*
layout_sfa_ptr
=
ScaleConfig
::
tile_atom_to_shape_SFA
(
cute
::
make_shape
(
static_cast
<
int
>
(
m
),
static_cast
<
int
>
(
n
),
static_cast
<
int
>
(
k
),
1
));
*
layout_sfb_ptr
=
ScaleConfig
::
tile_atom_to_shape_SFB
(
cute
::
make_shape
(
static_cast
<
int
>
(
m
),
static_cast
<
int
>
(
n
),
static_cast
<
int
>
(
k
),
1
));
}
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \
TENSOR_C_TYPE, C_TYPE, LayoutSFA, \
LayoutSFB, ScaleConfig) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, \
LayoutSFA, LayoutSFB, ScaleConfig> \
<<<1, num_experts, 0, stream>>>( \
static_cast<ELEMENT_AB_TYPE**>(a_starts.data_ptr()), \
static_cast<ELEMENT_AB_TYPE**>(b_starts.data_ptr()), \
static_cast<C_TYPE**>(out_starts.data_ptr()), \
static_cast<SF_TYPE**>(a_scales_starts.data_ptr()), \
static_cast<SF_TYPE**>(b_scales_starts.data_ptr()), \
static_cast<float**>(alpha_starts.data_ptr()), \
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(a_tensors.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<SF_TYPE*>(a_scales.data_ptr()), \
static_cast<SF_TYPE*>(b_scales.data_ptr()), \
static_cast<float*>(alphas.data_ptr()), \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<int32_t*>(sf_offsets.data_ptr()), \
static_cast<int32_t*>(problem_sizes.data_ptr()), K, N); \
}
template
<
typename
LayoutSFA
,
typename
LayoutSFB
,
typename
ScaleConfig
>
void
run_get_group_gemm_starts
(
const
torch
::
Tensor
&
a_starts
,
const
torch
::
Tensor
&
b_starts
,
const
torch
::
Tensor
&
out_starts
,
const
torch
::
Tensor
&
a_scales_starts
,
const
torch
::
Tensor
&
b_scales_starts
,
const
torch
::
Tensor
&
alpha_starts
,
const
torch
::
Tensor
&
layout_sfa
,
const
torch
::
Tensor
&
layout_sfb
,
/*these are used for their base addresses*/
torch
::
Tensor
const
&
a_tensors
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
out_tensors
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
alphas
,
torch
::
Tensor
const
&
expert_offsets
,
torch
::
Tensor
const
&
sf_offsets
,
torch
::
Tensor
const
&
problem_sizes
,
int
M
,
int
N
,
int
K
)
{
int
num_experts
=
(
int
)
expert_offsets
.
size
(
0
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a_tensors
.
device
().
index
());
TORCH_CHECK
(
out_tensors
.
size
(
1
)
==
N
,
"Output tensor shape doesn't match expected shape"
);
TORCH_CHECK
(
K
/
2
==
b_tensors
.
size
(
2
),
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
" dimension must match"
);
if
(
false
)
{
}
//(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
// ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE
(
cutlass
::
float_e2m1_t
,
cutlass
::
float_ue4m3_t
,
torch
::
kBFloat16
,
cutlass
::
bfloat16_t
,
LayoutSFA
,
LayoutSFB
,
ScaleConfig
)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE
(
cutlass
::
float_e2m1_t
,
cutlass
::
float_ue4m3_t
,
torch
::
kFloat16
,
half
,
LayoutSFA
,
LayoutSFB
,
ScaleConfig
)
else
{
TORCH_CHECK
(
false
,
"Invalid output type (must be float16 or bfloat16)"
);
}
}
template
<
typename
OutType
>
void
run_fp4_blockwise_scaled_group_mm
(
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
a_blockscale
,
const
torch
::
Tensor
&
b_blockscales
,
const
torch
::
Tensor
&
alphas
,
const
torch
::
Tensor
&
problem_sizes
,
const
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
sf_offsets
,
int
M
,
int
N
,
int
K
)
{
using
ProblemShape
=
cutlass
::
gemm
::
GroupProblemShape
<
Shape
<
int32_t
,
int32_t
,
int32_t
>>
;
using
ElementType
=
cutlass
::
float_e2m1_t
;
using
ElementSFType
=
cutlass
::
float_ue4m3_t
;
using
ElementA
=
cutlass
::
nv_float4_t
<
cutlass
::
float_e2m1_t
>
;
using
ElementB
=
cutlass
::
nv_float4_t
<
cutlass
::
float_e2m1_t
>
;
using
ElementC
=
OutType
;
using
ElementD
=
ElementC
;
using
ElementAccumulator
=
float
;
// Layout definitions
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
using
LayoutC
=
cutlass
::
layout
::
RowMajor
;
using
LayoutD
=
LayoutC
;
// Alignment constraints
static
constexpr
int
AlignmentA
=
32
;
static
constexpr
int
AlignmentB
=
32
;
static
constexpr
int
AlignmentC
=
128
/
cutlass
::
sizeof_bits
<
ElementC
>::
value
;
static
constexpr
int
AlignmentD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
// Architecture definitions
using
ArchTag
=
cutlass
::
arch
::
Sm100
;
using
EpilogueOperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
// Epilogue Operator class tag
using
MainloopOperatorClass
=
cutlass
::
arch
::
OpClassBlockScaledTensorOp
;
// Mainloop Operator class tag
using
StageCountType
=
cutlass
::
gemm
::
collective
::
StageCountAuto
;
// Stage count maximized based
// on the tile size
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
struct
MMA1SMConfig
{
using
MmaTileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
KernelSchedule
=
cutlass
::
gemm
::
KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100
;
// Kernel to launch
using
EpilogueSchedule
=
cutlass
::
epilogue
::
PtrArrayTmaWarpSpecialized1Sm
;
// Epilogue to launch
};
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
EpilogueOperatorClass
,
typename
MMA1SMConfig
::
MmaTileShape
,
ClusterShape
,
Shape
<
_128
,
_64
>
,
ElementAccumulator
,
ElementAccumulator
,
ElementC
,
LayoutC
*
,
AlignmentC
,
ElementD
,
LayoutC
*
,
AlignmentD
,
typename
MMA1SMConfig
::
EpilogueSchedule
>::
CollectiveOp
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
MainloopOperatorClass
,
ElementA
,
LayoutA
*
,
AlignmentA
,
ElementB
,
LayoutB
*
,
AlignmentB
,
ElementAccumulator
,
typename
MMA1SMConfig
::
MmaTileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
typename
MMA1SMConfig
::
KernelSchedule
>::
CollectiveOp
;
using
GemmKernel
=
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
ProblemShape
,
CollectiveMainloop
,
CollectiveEpilogue
>
;
using
Gemm1SM
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
using
Gemm
=
Gemm1SM
;
using
StrideA
=
typename
Gemm
::
GemmKernel
::
InternalStrideA
;
using
StrideB
=
typename
Gemm
::
GemmKernel
::
InternalStrideB
;
using
StrideC
=
typename
Gemm
::
GemmKernel
::
InternalStrideC
;
using
StrideD
=
typename
Gemm
::
GemmKernel
::
InternalStrideD
;
using
LayoutSFA
=
typename
Gemm
::
GemmKernel
::
CollectiveMainloop
::
InternalLayoutSFA
;
using
LayoutSFB
=
typename
Gemm
::
GemmKernel
::
CollectiveMainloop
::
InternalLayoutSFB
;
using
ScaleConfig
=
typename
Gemm
::
GemmKernel
::
CollectiveMainloop
::
Sm1xxBlkScaledConfig
;
using
UnderlyingProblemShape
=
ProblemShape
::
UnderlyingProblemShape
;
int
num_experts
=
static_cast
<
int
>
(
expert_offsets
.
size
(
0
));
auto
options_int
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
a
.
device
());
torch
::
Tensor
a_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
b_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
out_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
a_scales_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
b_scales_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
alpha_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
layout_sfa
=
torch
::
empty
({
num_experts
,
5
},
options_int
);
torch
::
Tensor
layout_sfb
=
torch
::
empty
({
num_experts
,
5
},
options_int
);
torch
::
Tensor
c_strides1
=
torch
::
full
({
num_experts
},
output
.
stride
(
0
),
options_int
);
torch
::
Tensor
a_strides1
=
torch
::
full
({
num_experts
},
a
.
stride
(
0
)
*
2
,
options_int
);
torch
::
Tensor
b_strides1
=
torch
::
full
({
num_experts
},
b
.
stride
(
1
)
*
2
,
options_int
);
run_get_group_gemm_starts
<
LayoutSFA
,
LayoutSFB
,
ScaleConfig
>
(
a_ptrs
,
b_ptrs
,
out_ptrs
,
a_scales_ptrs
,
b_scales_ptrs
,
alpha_ptrs
,
layout_sfa
,
layout_sfb
,
a
,
b
,
output
,
a_blockscale
,
b_blockscales
,
alphas
,
expert_offsets
,
sf_offsets
,
problem_sizes
,
M
,
N
,
K
);
// Create an instance of the GEMM
Gemm
gemm_op
;
// Initialize problem_sizes_as_shapes correctly
UnderlyingProblemShape
*
problem_sizes_as_shapes
=
static_cast
<
UnderlyingProblemShape
*>
(
problem_sizes
.
data_ptr
());
// Set the Scheduler info
cutlass
::
KernelHardwareInfo
hw_info
;
using
RasterOrderOptions
=
typename
cutlass
::
gemm
::
kernel
::
detail
::
PersistentTileSchedulerSm100GroupParams
<
typename
ProblemShape
::
UnderlyingProblemShape
>::
RasterOrderOptions
;
typename
Gemm
::
GemmKernel
::
TileSchedulerArguments
scheduler
;
scheduler
.
raster_order
=
RasterOrderOptions
::
AlongM
;
hw_info
.
device_id
=
a
.
get_device
();
static
std
::
unordered_map
<
int
,
int
>
cached_sm_counts
;
if
(
cached_sm_counts
.
find
(
hw_info
.
device_id
)
==
cached_sm_counts
.
end
())
{
cached_sm_counts
[
hw_info
.
device_id
]
=
cutlass
::
KernelHardwareInfo
::
query_device_multiprocessor_count
(
hw_info
.
device_id
);
}
hw_info
.
sm_count
=
min
(
cached_sm_counts
[
hw_info
.
device_id
],
INT_MAX
);
// Mainloop Arguments
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
static_cast
<
const
ElementType
**>
(
a_ptrs
.
data_ptr
()),
static_cast
<
StrideA
*>
(
a_strides1
.
data_ptr
()),
static_cast
<
const
ElementType
**>
(
b_ptrs
.
data_ptr
()),
static_cast
<
StrideB
*>
(
b_strides1
.
data_ptr
()),
static_cast
<
const
ElementSFType
**>
(
a_scales_ptrs
.
data_ptr
()),
reinterpret_cast
<
LayoutSFA
*>
(
layout_sfa
.
data_ptr
()),
static_cast
<
const
ElementSFType
**>
(
b_scales_ptrs
.
data_ptr
()),
reinterpret_cast
<
LayoutSFB
*>
(
layout_sfb
.
data_ptr
())};
// Epilogue Arguments
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
{},
// epilogue.thread
nullptr
,
static_cast
<
StrideC
*>
(
c_strides1
.
data_ptr
()),
static_cast
<
ElementD
**>
(
out_ptrs
.
data_ptr
()),
static_cast
<
StrideC
*>
(
c_strides1
.
data_ptr
())};
auto
&
fusion_args
=
epilogue_args
.
thread
;
fusion_args
.
alpha_ptr_array
=
reinterpret_cast
<
float
**>
(
alpha_ptrs
.
data_ptr
());
fusion_args
.
dAlpha
=
{
_0
{},
_0
{},
1
};
// Gemm Arguments
typename
GemmKernel
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGrouped
,
{
num_experts
,
problem_sizes_as_shapes
,
nullptr
},
mainloop_args
,
epilogue_args
,
hw_info
,
scheduler
};
size_t
workspace_size
=
Gemm
::
get_workspace_size
(
args
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
auto
can_implement_status
=
gemm_op
.
can_implement
(
args
);
TORCH_CHECK
(
can_implement_status
==
cutlass
::
Status
::
kSuccess
,
"Failed to implement GEMM"
);
// Run the GEMM
auto
status
=
gemm_op
.
initialize
(
args
,
workspace
.
data_ptr
());
TORCH_CHECK
(
status
==
cutlass
::
Status
::
kSuccess
,
"Failed to initialize GEMM"
);
status
=
gemm_op
.
run
(
args
,
workspace
.
data_ptr
(),
stream
);
TORCH_CHECK
(
status
==
cutlass
::
Status
::
kSuccess
,
"Failed to run GEMM"
);
}
constexpr
auto
FLOAT4_E2M1X2
=
at
::
ScalarType
::
Byte
;
constexpr
auto
SF_DTYPE
=
at
::
ScalarType
::
Float8_e4m3fn
;
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
void
cutlass_fp4_group_mm
(
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
a_blockscale
,
const
torch
::
Tensor
&
b_blockscales
,
const
torch
::
Tensor
&
alphas
,
const
torch
::
Tensor
&
problem_sizes
,
const
torch
::
Tensor
&
expert_offsets
,
const
torch
::
Tensor
&
sf_offsets
)
{
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
// Input validation
CHECK_INPUT
(
a
,
FLOAT4_E2M1X2
,
"a"
);
CHECK_INPUT
(
b
,
FLOAT4_E2M1X2
,
"b"
);
CHECK_INPUT
(
a_blockscale
,
SF_DTYPE
,
"a_blockscale"
);
CHECK_INPUT
(
b_blockscales
,
SF_DTYPE
,
"b_blockscales"
);
CHECK_INPUT
(
alphas
,
at
::
ScalarType
::
Float
,
"alphas"
);
TORCH_CHECK
(
a_blockscale
.
dim
()
==
2
,
"expected a_blockscale to be of shape [num_experts, rounded_m,"
" k // group_size], observed rank: "
,
a_blockscale
.
dim
())
TORCH_CHECK
(
b_blockscales
.
dim
()
==
3
,
"expected b_blockscale to be of shape: "
" [num_experts, n, k // group_size], observed rank: "
,
b_blockscales
.
dim
())
TORCH_CHECK
(
problem_sizes
.
dim
()
==
2
,
"problem_sizes must be a 2D tensor"
);
TORCH_CHECK
(
problem_sizes
.
size
(
1
)
==
3
,
"problem_sizes must have the shape (num_experts, 3)"
);
TORCH_CHECK
(
problem_sizes
.
size
(
0
)
==
expert_offsets
.
size
(
0
),
"Number of experts in problem_sizes must match expert_offsets"
);
TORCH_CHECK
(
problem_sizes
.
dtype
()
==
torch
::
kInt32
,
"problem_sizes must be int32."
);
int
M
=
static_cast
<
int
>
(
a
.
size
(
0
));
int
N
=
static_cast
<
int
>
(
b
.
size
(
1
));
int
E
=
static_cast
<
int
>
(
b
.
size
(
0
));
int
K
=
static_cast
<
int
>
(
2
*
b
.
size
(
2
));
if
(
output
.
scalar_type
()
==
torch
::
kBFloat16
)
{
run_fp4_blockwise_scaled_group_mm
<
cutlass
::
bfloat16_t
>
(
output
,
a
,
b
,
a_blockscale
,
b_blockscales
,
alphas
,
problem_sizes
,
expert_offsets
,
sf_offsets
,
M
,
N
,
K
);
}
else
{
run_fp4_blockwise_scaled_group_mm
<
cutlass
::
half_t
>
(
output
,
a
,
b
,
a_blockscale
,
b_blockscales
,
alphas
,
problem_sizes
,
expert_offsets
,
sf_offsets
,
M
,
N
,
K
);
}
#else
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
"be compiled with ENABLE_NVFP4 for SM100+ and CUDA "
"12.8 or above."
);
#endif
}
csrc/quantization/fp4/nvfp4_experts_quant.cu
0 → 100644
View file @
7a985548
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h>
template
<
typename
T
>
struct
TypeConverter
{
using
Type
=
half2
;
};
// keep for generality
template
<
>
struct
TypeConverter
<
half2
>
{
using
Type
=
half
;
};
template
<
>
struct
TypeConverter
<
half
>
{
using
Type
=
half2
;
};
template
<
>
struct
TypeConverter
<
__nv_bfloat162
>
{
using
Type
=
__nv_bfloat16
;
};
template
<
>
struct
TypeConverter
<
__nv_bfloat16
>
{
using
Type
=
__nv_bfloat162
;
};
#define ELTS_PER_THREAD 8
constexpr
int
CVT_FP4_ELTS_PER_THREAD
=
8
;
constexpr
int
CVT_FP4_SF_VEC_SIZE
=
16
;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline
__device__
uint32_t
fp32_vec_to_e2m1
(
float
(
&
array
)[
8
])
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t
val
;
asm
volatile
(
"{
\n
"
".reg .b8 byte0;
\n
"
".reg .b8 byte1;
\n
"
".reg .b8 byte2;
\n
"
".reg .b8 byte3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;
\n
"
"mov.b32 %0, {byte0, byte1, byte2, byte3};
\n
"
"}"
:
"=r"
(
val
)
:
"f"
(
array
[
0
]),
"f"
(
array
[
1
]),
"f"
(
array
[
2
]),
"f"
(
array
[
3
]),
"f"
(
array
[
4
]),
"f"
(
array
[
5
]),
"f"
(
array
[
6
]),
"f"
(
array
[
7
]));
return
val
;
#else
return
0
;
#endif
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline
__device__
uint32_t
fp32_vec_to_e2m1
(
float2
(
&
array
)[
4
])
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t
val
;
asm
volatile
(
"{
\n
"
".reg .b8 byte0;
\n
"
".reg .b8 byte1;
\n
"
".reg .b8 byte2;
\n
"
".reg .b8 byte3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;
\n
"
"mov.b32 %0, {byte0, byte1, byte2, byte3};
\n
"
"}"
:
"=r"
(
val
)
:
"f"
(
array
[
0
].
x
),
"f"
(
array
[
0
].
y
),
"f"
(
array
[
1
].
x
),
"f"
(
array
[
1
].
y
),
"f"
(
array
[
2
].
x
),
"f"
(
array
[
2
].
y
),
"f"
(
array
[
3
].
x
),
"f"
(
array
[
3
].
y
));
return
val
;
#else
return
0
;
#endif
}
// Fast reciprocal.
inline
__device__
float
reciprocal_approximate_ftz
(
float
a
)
{
float
b
;
asm
volatile
(
"rcp.approx.ftz.f32 %0, %1;
\n
"
:
"=f"
(
b
)
:
"f"
(
a
));
return
b
;
}
template
<
class
SFType
,
int
CVT_FP4_NUM_THREADS_PER_SF
>
__device__
uint8_t
*
cvt_quant_to_fp4_get_sf_out_offset
(
int
rowIdx
,
int
colIdx
,
int
numCols
,
SFType
*
SFout
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert
(
CVT_FP4_NUM_THREADS_PER_SF
==
1
||
CVT_FP4_NUM_THREADS_PER_SF
==
2
);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if
(
threadIdx
.
x
%
CVT_FP4_NUM_THREADS_PER_SF
==
0
)
{
// SF vector index (16 elements share one SF in the K dimension).
int32_t
kIdx
=
colIdx
/
CVT_FP4_NUM_THREADS_PER_SF
;
int32_t
mIdx
=
rowIdx
;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t
mTileIdx
=
mIdx
/
(
32
*
4
);
// SF vector size 16.
int
factor
=
CVT_FP4_SF_VEC_SIZE
*
4
;
int32_t
numKTiles
=
(
numCols
+
factor
-
1
)
/
factor
;
int64_t
mTileStride
=
numKTiles
*
32
*
4
*
4
;
int32_t
kTileIdx
=
(
kIdx
/
4
);
int64_t
kTileStride
=
32
*
4
*
4
;
// M tile layout [32, 4] is column-major.
int32_t
outerMIdx
=
(
mIdx
%
32
);
int64_t
outerMStride
=
4
*
4
;
int32_t
innerMIdx
=
(
mIdx
%
(
32
*
4
))
/
32
;
int64_t
innerMStride
=
4
;
int32_t
innerKIdx
=
(
kIdx
%
4
);
int64_t
innerKStride
=
1
;
// Compute the global offset.
int64_t
SFOffset
=
mTileIdx
*
mTileStride
+
kTileIdx
*
kTileStride
+
outerMIdx
*
outerMStride
+
innerMIdx
*
innerMStride
+
innerKIdx
*
innerKStride
;
return
reinterpret_cast
<
uint8_t
*>
(
SFout
)
+
SFOffset
;
}
#endif
return
nullptr
;
}
// Define a 16 bytes packed data type.
template
<
class
Type
>
struct
PackedVec
{
typename
TypeConverter
<
Type
>::
Type
elts
[
4
];
};
template
<
>
struct
PackedVec
<
__nv_fp8_e4m3
>
{
__nv_fp8x2_e4m3
elts
[
8
];
};
// Quantizes the provided PackedVec into the uint32_t output
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
__device__
uint32_t
cvt_warp_fp16_to_fp4
(
PackedVec
<
Type
>&
vec
,
float
SFScaleVal
,
uint8_t
*
SFout
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Get absolute maximum values among the local 8 values.
auto
localMax
=
__habs2
(
vec
.
elts
[
0
]);
// Local maximum value.
#pragma unroll
for
(
int
i
=
1
;
i
<
CVT_FP4_ELTS_PER_THREAD
/
2
;
i
++
)
{
localMax
=
__hmax2
(
localMax
,
__habs2
(
vec
.
elts
[
i
]));
}
// Get the absolute maximum among all 16 values (two threads).
localMax
=
__hmax2
(
__shfl_xor_sync
(
uint32_t
(
-
1
),
localMax
,
1
),
localMax
);
// Get the final absolute maximum values.
float
vecMax
=
float
(
__hmax
(
localMax
.
x
,
localMax
.
y
));
// Get the SF (max value of the vector / max value of e2m1).
// maximum value of e2m1 = 6.0.
// TODO: use half as compute data type.
float
SFValue
=
SFScaleVal
*
(
vecMax
*
reciprocal_approximate_ftz
(
6.0
f
));
// 8 bits representation of the SF.
uint8_t
fp8SFVal
;
// Write the SF to global memory (STG.8).
if
constexpr
(
UE8M0_SF
)
{
// Extract the 8 exponent bits from float32.
// float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
uint32_t
tmp
=
reinterpret_cast
<
uint32_t
&>
(
SFValue
)
>>
23
;
fp8SFVal
=
tmp
&
0xff
;
// Convert back to fp32.
reinterpret_cast
<
uint32_t
&>
(
SFValue
)
=
tmp
<<
23
;
}
else
{
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
__nv_fp8_e4m3
tmp
=
__nv_fp8_e4m3
(
SFValue
);
reinterpret_cast
<
__nv_fp8_e4m3
&>
(
fp8SFVal
)
=
tmp
;
// Convert back to fp32.
SFValue
=
float
(
tmp
);
}
// Get the output scale.
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
// reciprocal(SFScaleVal))
float
outputScale
=
SFValue
!=
0
?
reciprocal_approximate_ftz
(
SFValue
*
reciprocal_approximate_ftz
(
SFScaleVal
))
:
0.0
f
;
if
(
SFout
)
{
// Write the SF to global memory (STG.8).
*
SFout
=
fp8SFVal
;
}
// Convert the input to float.
float2
fp2Vals
[
CVT_FP4_ELTS_PER_THREAD
/
2
];
#pragma unroll
for
(
int
i
=
0
;
i
<
CVT_FP4_ELTS_PER_THREAD
/
2
;
i
++
)
{
if
constexpr
(
std
::
is_same_v
<
Type
,
half
>
)
{
fp2Vals
[
i
]
=
__half22float2
(
vec
.
elts
[
i
]);
}
else
{
fp2Vals
[
i
]
=
__bfloat1622float2
(
vec
.
elts
[
i
]);
}
fp2Vals
[
i
].
x
*=
outputScale
;
fp2Vals
[
i
].
y
*=
outputScale
;
}
// Convert to e2m1 values.
uint32_t
e2m1Vec
=
fp32_vec_to_e2m1
(
fp2Vals
);
// Write the e2m1 values to global memory.
return
e2m1Vec
;
#else
return
0
;
#endif
}
// Use UE4M3 by default.
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
__global__
void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__
(
512
,
4
)
cvt_fp16_to_fp4
(
#else
cvt_fp16_to_fp4
(
#endif
int32_t
numRows
,
int32_t
numCols
,
Type
const
*
in
,
float
const
*
SFScale
,
uint32_t
*
out
,
uint32_t
*
SFout
,
uint32_t
*
input_offset_by_experts
,
uint32_t
*
output_scale_offset_by_experts
,
int
n_experts
)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using
PackedVec
=
PackedVec
<
Type
>
;
static
constexpr
int
CVT_FP4_NUM_THREADS_PER_SF
=
(
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
);
static_assert
(
sizeof
(
PackedVec
)
==
sizeof
(
Type
)
*
CVT_FP4_ELTS_PER_THREAD
,
"Vec size is not matched."
);
// Input tensor row/col loops.
for
(
int
rowIdx
=
blockIdx
.
x
;
rowIdx
<
numRows
;
rowIdx
+=
gridDim
.
x
)
{
for
(
int
colIdx
=
threadIdx
.
x
;
colIdx
<
numCols
/
CVT_FP4_ELTS_PER_THREAD
;
colIdx
+=
blockDim
.
x
)
{
int64_t
inOffset
=
rowIdx
*
(
numCols
/
CVT_FP4_ELTS_PER_THREAD
)
+
colIdx
;
PackedVec
in_vec
=
reinterpret_cast
<
PackedVec
const
*>
(
in
)[
inOffset
];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t
outOffset
=
inOffset
;
auto
&
out_pos
=
out
[
outOffset
];
// Find index within the experts.
int
rowIdx_in_expert
=
0
;
int
expert_idx
=
0
;
for
(
int
i
=
0
;
i
<
n_experts
;
i
++
)
{
if
(
rowIdx
>=
input_offset_by_experts
[
i
]
&&
rowIdx
<
input_offset_by_experts
[
i
+
1
])
{
rowIdx_in_expert
=
rowIdx
-
input_offset_by_experts
[
i
];
expert_idx
=
i
;
break
;
}
}
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float
const
SFScaleVal
=
SFScale
==
nullptr
?
1.0
f
:
SFScale
[
expert_idx
];
int
factor
=
CVT_FP4_SF_VEC_SIZE
*
4
;
// The actual output_scales dim is computed from the padded numCols.
int32_t
numCols_padded
=
(
numCols
+
factor
-
1
)
/
factor
*
factor
;
int
numCols_SFout
=
numCols_padded
/
CVT_FP4_SF_VEC_SIZE
/
4
;
uint32_t
*
SFout_in_expert
=
SFout
+
output_scale_offset_by_experts
[
expert_idx
]
*
numCols_SFout
;
auto
sf_out
=
cvt_quant_to_fp4_get_sf_out_offset
<
uint32_t
,
CVT_FP4_NUM_THREADS_PER_SF
>
(
rowIdx_in_expert
,
colIdx
,
numCols
,
SFout_in_expert
);
out_pos
=
cvt_warp_fp16_to_fp4
<
Type
,
UE8M0_SF
>
(
in_vec
,
SFScaleVal
,
sf_out
);
}
}
#endif
}
template
<
typename
T
>
void
quant_impl
(
void
*
output
,
void
*
output_scale
,
void
*
input
,
void
*
input_global_scale
,
void
*
input_offset_by_experts
,
void
*
output_scale_offset_by_experts
,
int
m_topk
,
int
k
,
int
n_experts
,
cudaStream_t
stream
)
{
// TODO: this multiProcessorCount should be cached.
int
device
;
cudaGetDevice
(
&
device
);
int
multiProcessorCount
;
cudaDeviceGetAttribute
(
&
multiProcessorCount
,
cudaDevAttrMultiProcessorCount
,
device
);
// Grid, Block size.
// Each thread converts 8 values.
dim3
block
(
std
::
min
(
int
(
k
/
ELTS_PER_THREAD
),
512
));
// Get number of blocks per SM (assume we can fully utilize the SM).
int
const
numBlocksPerSM
=
2048
/
block
.
x
;
dim3
grid
(
std
::
min
(
int
(
m_topk
),
multiProcessorCount
*
numBlocksPerSM
));
cvt_fp16_to_fp4
<
T
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
m_topk
,
k
,
reinterpret_cast
<
T
*>
(
input
),
reinterpret_cast
<
float
*>
(
input_global_scale
),
reinterpret_cast
<
uint32_t
*>
(
output
),
reinterpret_cast
<
uint32_t
*>
(
output_scale
),
reinterpret_cast
<
uint32_t
*>
(
input_offset_by_experts
),
reinterpret_cast
<
uint32_t
*>
(
output_scale_offset_by_experts
),
n_experts
);
}
/*Quantization entry for fp4 experts quantization*/
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m);
constexpr
auto
HALF
=
at
::
ScalarType
::
Half
;
constexpr
auto
BF16
=
at
::
ScalarType
::
BFloat16
;
constexpr
auto
FLOAT
=
at
::
ScalarType
::
Float
;
constexpr
auto
INT
=
at
::
ScalarType
::
Int
;
constexpr
auto
UINT8
=
at
::
ScalarType
::
Byte
;
void
scaled_fp4_experts_quant_sm100a
(
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_global_scale
,
torch
::
Tensor
const
&
input_offset_by_experts
,
torch
::
Tensor
const
&
output_scale_offset_by_experts
)
{
CHECK_INPUT
(
output
,
"output must be a CUDA tensor"
);
CHECK_INPUT
(
output_scale
,
"output_scale must be a CUDA tensor"
);
CHECK_INPUT
(
input
,
"input must be a CUDA tensor"
);
CHECK_INPUT
(
input_global_scale
,
"input_global_scale must be a CUDA tensor"
);
CHECK_INPUT
(
input_offset_by_experts
,
"input_offset_by_experts must be a CUDA tensor"
);
CHECK_INPUT
(
output_scale_offset_by_experts
,
"output_scale_offset_by_experts must be a CUDA tensor"
);
TORCH_CHECK
(
output
.
dim
()
==
2
);
TORCH_CHECK
(
output_scale
.
dim
()
==
2
);
TORCH_CHECK
(
input
.
dim
()
==
2
);
TORCH_CHECK
(
input_global_scale
.
dim
()
==
1
);
TORCH_CHECK
(
input_offset_by_experts
.
dim
()
==
1
);
TORCH_CHECK
(
output_scale_offset_by_experts
.
dim
()
==
1
);
TORCH_CHECK
(
input
.
scalar_type
()
==
HALF
||
input
.
scalar_type
()
==
BF16
);
TORCH_CHECK
(
input_global_scale
.
scalar_type
()
==
FLOAT
);
TORCH_CHECK
(
input_offset_by_experts
.
scalar_type
()
==
INT
);
TORCH_CHECK
(
output_scale_offset_by_experts
.
scalar_type
()
==
INT
);
// output is uint8 (two nvfp4 values are packed into one uint8)
// output_scale is int32 (four fp8 values are packed into one int32)
TORCH_CHECK
(
output
.
scalar_type
()
==
UINT8
);
TORCH_CHECK
(
output_scale
.
scalar_type
()
==
INT
);
const
int
BLOCK_SIZE
=
16
;
auto
m_topk
=
input
.
size
(
0
);
auto
k
=
input
.
size
(
1
);
TORCH_CHECK
(
k
%
BLOCK_SIZE
==
0
,
"k must be a multiple of 16"
);
auto
n_experts
=
input_global_scale
.
size
(
0
);
TORCH_CHECK
(
input_offset_by_experts
.
size
(
0
)
==
n_experts
+
1
);
TORCH_CHECK
(
output_scale_offset_by_experts
.
size
(
0
)
==
n_experts
+
1
);
TORCH_CHECK
(
output
.
size
(
0
)
==
m_topk
);
TORCH_CHECK
(
output
.
size
(
1
)
==
k
/
2
);
int
scales_k
=
k
/
BLOCK_SIZE
;
// 4 means the swizzle requirement by nvidia nvfp4.
int
padded_k
=
(
scales_k
+
(
4
-
1
))
/
4
*
4
;
// 4 means 4 fp8 values are packed into one int32
TORCH_CHECK
(
output_scale
.
size
(
1
)
*
4
==
padded_k
);
auto
in_dtype
=
input
.
dtype
();
at
::
cuda
::
CUDAGuard
device_guard
{(
char
)
input
.
get_device
()};
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
input
.
get_device
());
if
(
in_dtype
==
at
::
ScalarType
::
Half
)
{
quant_impl
<
half
>
(
output
.
data_ptr
(),
output_scale
.
data_ptr
(),
input
.
data_ptr
(),
input_global_scale
.
data_ptr
(),
input_offset_by_experts
.
data_ptr
(),
output_scale_offset_by_experts
.
data_ptr
(),
m_topk
,
k
,
n_experts
,
stream
);
}
else
if
(
in_dtype
==
at
::
ScalarType
::
BFloat16
)
{
quant_impl
<
__nv_bfloat16
>
(
output
.
data_ptr
(),
output_scale
.
data_ptr
(),
input
.
data_ptr
(),
input_global_scale
.
data_ptr
(),
input_offset_by_experts
.
data_ptr
(),
output_scale_offset_by_experts
.
data_ptr
(),
m_topk
,
k
,
n_experts
,
stream
);
}
else
{
TORCH_CHECK
(
false
,
"Expected input data type to be half or bfloat16"
);
}
}
\ No newline at end of file
csrc/quantization/fp4/nvfp4_quant_entry.cu
View file @
7a985548
...
...
@@ -23,10 +23,32 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output,
torch
::
Tensor
const
&
input_sf
);
#endif
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
void
scaled_fp4_experts_quant_sm100a
(
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_global_scale
,
torch
::
Tensor
const
&
input_offset_by_experts
,
torch
::
Tensor
const
&
output_scale_offset_by_experts
);
#endif
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
output_sf
,
torch
::
Tensor
const
&
input_sf
)
{
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
return
scaled_fp4_quant_sm100a
(
output
,
input
,
output_sf
,
input_sf
);
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled nvfp4 quantization"
);
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled nvfp4 quantization kernel"
);
}
void
scaled_fp4_experts_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_global_scale
,
torch
::
Tensor
const
&
input_offset_by_experts
,
torch
::
Tensor
const
&
output_scale_offset_by_experts
)
{
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
return
scaled_fp4_experts_quant_sm100a
(
output
,
output_scale
,
input
,
input_global_scale
,
input_offset_by_experts
,
output_scale_offset_by_experts
);
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled nvfp4 experts quantization kernel"
);
}
Prev
1
2
3
4
5
6
7
8
9
10
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment