Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ed7a29d9
Unverified
Commit
ed7a29d9
authored
Apr 27, 2025
by
Kaixi Hou
Committed by
GitHub
Apr 27, 2025
Browse files
[NVIDIA] Support Cutlass MLA for Blackwell GPUs (#16032)
Signed-off-by:
kaixih
<
kaixih@nvidia.com
>
parent
756848e7
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
403 additions
and
5 deletions
+403
-5
CMakeLists.txt
CMakeLists.txt
+24
-4
csrc/attention/mla/cutlass_mla_entry.cu
csrc/attention/mla/cutlass_mla_entry.cu
+38
-0
csrc/attention/mla/cutlass_mla_kernels.cu
csrc/attention/mla/cutlass_mla_kernels.cu
+225
-0
csrc/ops.h
csrc/ops.h
+6
-0
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
+1
-1
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+7
-0
tests/kernels/test_cutlass_mla_decode.py
tests/kernels/test_cutlass_mla_decode.py
+93
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+9
-0
No files found.
CMakeLists.txt
View file @
ed7a29d9
...
@@ -251,7 +251,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -251,7 +251,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
# Please keep this in sync with FetchContent_Declare line below.
# Please keep this in sync with FetchContent_Declare line below.
set
(
CUTLASS_REVISION
"v3.
8
.0"
CACHE STRING
"CUTLASS revision to use"
)
set
(
CUTLASS_REVISION
"v3.
9
.0"
CACHE STRING
"CUTLASS revision to use"
)
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
if
(
DEFINED ENV{VLLM_CUTLASS_SRC_DIR}
)
if
(
DEFINED ENV{VLLM_CUTLASS_SRC_DIR}
)
...
@@ -269,7 +269,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -269,7 +269,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
cutlass
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
# Please keep this in sync with CUTLASS_REVISION line above.
# Please keep this in sync with CUTLASS_REVISION line above.
GIT_TAG v3.
8
.0
GIT_TAG v3.
9
.0
GIT_PROGRESS TRUE
GIT_PROGRESS TRUE
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
...
@@ -290,7 +290,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -290,7 +290,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp"
)
"csrc/cutlass_extensions/common.cpp"
"csrc/attention/mla/cutlass_mla_entry.cu"
)
set_gencode_flags_for_srcs
(
set_gencode_flags_for_srcs
(
SRCS
"
${
VLLM_EXT_SRC
}
"
SRCS
"
${
VLLM_EXT_SRC
}
"
...
@@ -463,7 +464,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -463,7 +464,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set
(
FP4_ARCHS
)
set
(
FP4_ARCHS
)
endif
()
endif
()
#
# CUTLASS MLA Archs and flags
cuda_archs_loose_intersection
(
MLA_ARCHS
"10.0a"
"
${
CUDA_ARCHS
}
"
)
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.8 AND MLA_ARCHS
)
set
(
SRCS
"csrc/attention/mla/cutlass_mla_kernels.cu"
)
set_gencode_flags_for_srcs
(
SRCS
"
${
SRCS
}
"
CUDA_ARCHS
"
${
MLA_ARCHS
}
"
)
list
(
APPEND VLLM_EXT_SRC
"
${
SRCS
}
"
)
list
(
APPEND VLLM_GPU_FLAGS
"-DENABLE_CUTLASS_MLA=1"
)
# Add MLA-specific include directories only to MLA source files
set_source_files_properties
(
${
SRCS
}
PROPERTIES INCLUDE_DIRECTORIES
"
${
CUTLASS_DIR
}
/examples/77_blackwell_fmha;
${
CUTLASS_DIR
}
/examples/common"
)
message
(
STATUS
"Building CUTLASS MLA for archs:
${
MLA_ARCHS
}
"
)
else
()
message
(
STATUS
"Not building CUTLASS MLA as no compatible archs were found."
)
# clear MLA_ARCHS
set
(
MLA_ARCHS
)
endif
()
# CUTLASS MoE kernels
# CUTLASS MoE kernels
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
...
...
csrc/attention/mla/cutlass_mla_entry.cu
0 → 100644
View file @
ed7a29d9
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
void
cutlass_mla_decode_sm100a
(
torch
::
Tensor
const
&
out
,
torch
::
Tensor
const
&
q_nope
,
torch
::
Tensor
const
&
q_pe
,
torch
::
Tensor
const
&
kv_c_and_k_pe_cache
,
torch
::
Tensor
const
&
seq_lens
,
torch
::
Tensor
const
&
page_table
,
double
scale
);
#endif
void
cutlass_mla_decode
(
torch
::
Tensor
const
&
out
,
torch
::
Tensor
const
&
q_nope
,
torch
::
Tensor
const
&
q_pe
,
torch
::
Tensor
const
&
kv_c_and_k_pe_cache
,
torch
::
Tensor
const
&
seq_lens
,
torch
::
Tensor
const
&
page_table
,
double
scale
)
{
#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
return
cutlass_mla_decode_sm100a
(
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
scale
);
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled cutlass MLA"
);
}
csrc/attention/mla/cutlass_mla_kernels.cu
0 → 100644
View file @
ed7a29d9
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/kernel_hardware_info.h"
#include "cutlass_extensions/common.hpp"
#include "device/sm100_mla.hpp"
#include "kernel/sm100_mla_tile_scheduler.hpp"
using
namespace
cute
;
using
namespace
cutlass
::
fmha
::
kernel
;
template
<
typename
T
,
bool
PersistenceOption
=
true
>
struct
MlaSm100
{
using
Element
=
T
;
using
ElementAcc
=
float
;
using
ElementOut
=
T
;
using
TileShape
=
Shape
<
_128
,
_128
,
Shape
<
_512
,
_64
>>
;
using
TileShapeH
=
cute
::
tuple_element_t
<
0
,
TileShape
>
;
using
TileShapeD
=
cute
::
tuple_element_t
<
2
,
TileShape
>
;
// H K (D_latent D_rope) B
using
ProblemShape
=
cute
::
tuple
<
TileShapeH
,
int
,
TileShapeD
,
int
>
;
using
StrideQ
=
cute
::
tuple
<
int64_t
,
_1
,
int64_t
>
;
// H D B
using
StrideK
=
cute
::
tuple
<
int64_t
,
_1
,
int64_t
>
;
// K D B
using
StrideO
=
StrideK
;
// H D B
using
StrideLSE
=
cute
::
tuple
<
_1
,
int
>
;
// H B
using
TileScheduler
=
std
::
conditional_t
<
PersistenceOption
,
Sm100MlaPersistentTileScheduler
,
Sm100MlaIndividualTileScheduler
>
;
using
FmhaKernel
=
cutlass
::
fmha
::
kernel
::
Sm100FmhaMlaKernelTmaWarpspecialized
<
TileShape
,
Element
,
ElementAcc
,
ElementOut
,
ElementAcc
,
TileScheduler
,
/*kIsCpAsync=*/
true
>
;
using
Fmha
=
cutlass
::
fmha
::
device
::
MLA
<
FmhaKernel
>
;
};
template
<
typename
T
>
typename
T
::
Fmha
::
Arguments
args_from_options
(
at
::
Tensor
const
&
out
,
at
::
Tensor
const
&
q_nope
,
at
::
Tensor
const
&
q_pe
,
at
::
Tensor
const
&
kv_c_and_k_pe_cache
,
at
::
Tensor
const
&
seq_lens
,
at
::
Tensor
const
&
page_table
,
double
scale
)
{
cutlass
::
KernelHardwareInfo
hw_info
;
hw_info
.
device_id
=
q_nope
.
device
().
index
();
hw_info
.
sm_count
=
cutlass
::
KernelHardwareInfo
::
query_device_multiprocessor_count
(
hw_info
.
device_id
);
int
batches
=
q_nope
.
sizes
()[
0
];
int
page_count_per_seq
=
page_table
.
sizes
()[
1
];
int
page_count_total
=
kv_c_and_k_pe_cache
.
sizes
()[
0
];
int
page_size
=
kv_c_and_k_pe_cache
.
sizes
()[
1
];
int
max_seq_len
=
page_size
*
page_count_per_seq
;
using
TileShapeH
=
typename
T
::
TileShapeH
;
using
TileShapeD
=
typename
T
::
TileShapeD
;
auto
problem_shape
=
cute
::
make_tuple
(
TileShapeH
{},
max_seq_len
,
TileShapeD
{},
batches
);
auto
[
H
,
K
,
D
,
B
]
=
problem_shape
;
auto
[
D_latent
,
D_rope
]
=
D
;
using
StrideQ
=
typename
T
::
StrideQ
;
using
StrideK
=
typename
T
::
StrideK
;
using
StrideO
=
typename
T
::
StrideO
;
using
StrideLSE
=
typename
T
::
StrideLSE
;
StrideQ
stride_Q_latent
=
cute
::
make_tuple
(
static_cast
<
int64_t
>
(
D_latent
),
_1
{},
static_cast
<
int64_t
>
(
H
*
D_latent
));
StrideQ
stride_Q_rope
=
cute
::
make_tuple
(
static_cast
<
int64_t
>
(
D_rope
),
_1
{},
static_cast
<
int64_t
>
(
H
*
D_rope
));
StrideK
stride_C
=
cute
::
make_tuple
(
static_cast
<
int64_t
>
(
D_latent
+
D_rope
),
_1
{},
static_cast
<
int64_t
>
(
page_size
*
(
D_latent
+
D_rope
)));
StrideLSE
stride_PT
=
cute
::
make_stride
(
_1
{},
page_count_per_seq
);
StrideLSE
stride_LSE
=
cute
::
make_tuple
(
_1
{},
static_cast
<
int
>
(
H
));
StrideO
stride_O
=
cute
::
make_tuple
(
static_cast
<
int64_t
>
(
D_latent
),
_1
{},
static_cast
<
int64_t
>
(
H
*
D_latent
));
using
Element
=
typename
T
::
Element
;
using
ElementOut
=
typename
T
::
ElementOut
;
using
ElementAcc
=
typename
T
::
ElementAcc
;
auto
Q_latent_ptr
=
static_cast
<
Element
*>
(
q_nope
.
data_ptr
());
auto
Q_rope_ptr
=
static_cast
<
Element
*>
(
q_pe
.
data_ptr
());
auto
C_ptr
=
static_cast
<
Element
*>
(
kv_c_and_k_pe_cache
.
data_ptr
());
auto
scale_f
=
static_cast
<
float
>
(
scale
);
typename
T
::
Fmha
::
Arguments
arguments
{
problem_shape
,
{
scale_f
,
Q_latent_ptr
,
stride_Q_latent
,
Q_rope_ptr
,
stride_Q_rope
,
C_ptr
,
stride_C
,
C_ptr
+
D_latent
,
stride_C
,
static_cast
<
int
*>
(
seq_lens
.
data_ptr
()),
static_cast
<
int
*>
(
page_table
.
data_ptr
()),
stride_PT
,
page_count_total
,
page_size
},
{
static_cast
<
ElementOut
*>
(
out
.
data_ptr
()),
stride_O
,
static_cast
<
ElementAcc
*>
(
nullptr
),
stride_LSE
},
hw_info
,
-
1
,
// split_kv
nullptr
,
// is_var_split_kv
};
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
// split_kv automatically based on batch size and sequence length to balance
// workload across available SMs. Consider using var_split_kv for manual
// control if needed.
T
::
Fmha
::
set_split_kv
(
arguments
);
return
arguments
;
}
template
<
typename
Element
>
void
runMla
(
at
::
Tensor
const
&
out
,
at
::
Tensor
const
&
q_nope
,
at
::
Tensor
const
&
q_pe
,
at
::
Tensor
const
&
kv_c_and_k_pe_cache
,
at
::
Tensor
const
&
seq_lens
,
at
::
Tensor
const
&
page_table
,
float
scale
,
cudaStream_t
stream
)
{
using
MlaSm100Type
=
MlaSm100
<
Element
>
;
typename
MlaSm100Type
::
Fmha
fmha
;
auto
arguments
=
args_from_options
<
MlaSm100Type
>
(
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
scale
);
size_t
workspace_size
=
MlaSm100Type
::
Fmha
::
get_workspace_size
(
arguments
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
q_nope
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
CUTLASS_CHECK
(
fmha
.
can_implement
(
arguments
));
CUTLASS_CHECK
(
fmha
.
initialize
(
arguments
,
workspace
.
data_ptr
(),
stream
));
CUTLASS_CHECK
(
fmha
.
run
(
arguments
,
workspace
.
data_ptr
(),
stream
));
}
void
cutlass_mla_decode_sm100a
(
torch
::
Tensor
const
&
out
,
torch
::
Tensor
const
&
q_nope
,
torch
::
Tensor
const
&
q_pe
,
torch
::
Tensor
const
&
kv_c_and_k_pe_cache
,
torch
::
Tensor
const
&
seq_lens
,
torch
::
Tensor
const
&
page_table
,
double
scale
)
{
TORCH_CHECK
(
q_nope
.
device
().
is_cuda
(),
"q_nope must be on CUDA"
);
TORCH_CHECK
(
q_nope
.
dim
()
==
3
,
"q_nope must be a 3D tensor"
);
TORCH_CHECK
(
q_pe
.
dim
()
==
3
,
"q_pe must be a 3D tensor"
);
TORCH_CHECK
(
kv_c_and_k_pe_cache
.
dim
()
==
3
,
"kv_c_and_k_pe_cache must be a 3D tensor"
);
TORCH_CHECK
(
seq_lens
.
dim
()
==
1
,
"seq_lens must be a 1D tensor"
);
TORCH_CHECK
(
page_table
.
dim
()
==
2
,
"page_table must be a 2D tensor"
);
TORCH_CHECK
(
out
.
dim
()
==
3
,
"out must be a 3D tensor"
);
auto
B_q_nope
=
q_nope
.
size
(
0
);
auto
H_q_nope
=
q_nope
.
size
(
1
);
auto
D_q_nope
=
q_nope
.
size
(
2
);
auto
B_q_pe
=
q_pe
.
size
(
0
);
auto
H_q_pe
=
q_pe
.
size
(
1
);
auto
D_q_pe
=
q_pe
.
size
(
2
);
auto
B_pt
=
page_table
.
size
(
0
);
auto
PAGE_NUM
=
page_table
.
size
(
1
);
auto
PAGE_SIZE
=
kv_c_and_k_pe_cache
.
size
(
1
);
auto
D_ckv
=
kv_c_and_k_pe_cache
.
size
(
2
);
auto
B_o
=
out
.
size
(
0
);
auto
H_o
=
out
.
size
(
1
);
auto
D_o
=
out
.
size
(
2
);
TORCH_CHECK
(
D_q_nope
==
512
,
"D_q_nope must be equal to 512"
);
TORCH_CHECK
(
D_q_pe
==
64
,
"D_q_pe must be equal to 64"
);
TORCH_CHECK
(
D_ckv
==
576
,
"D_ckv must be equal to 576"
);
TORCH_CHECK
(
H_q_nope
==
H_q_pe
&&
H_q_nope
==
H_o
&&
H_o
==
128
,
"H_q_nope, H_q_pe, and H_o must be equal to 128"
);
TORCH_CHECK
(
PAGE_SIZE
>
0
&&
(
PAGE_SIZE
&
(
PAGE_SIZE
-
1
))
==
0
,
"PAGE_SIZE must be a power of 2"
);
TORCH_CHECK
(
B_q_nope
==
B_q_pe
&&
B_q_nope
==
B_pt
&&
B_q_nope
==
B_o
,
"Batch dims must be same for page_table, q_nope and q_pe, and out"
);
TORCH_CHECK
(
PAGE_NUM
%
(
128
/
PAGE_SIZE
)
==
0
,
"PAGE_NUM must be divisible by 128 / PAGE_SIZE"
);
TORCH_CHECK
(
D_o
==
512
,
"D_o must be equal to 512"
);
TORCH_CHECK
(
q_nope
.
dtype
()
==
at
::
ScalarType
::
Half
||
q_nope
.
dtype
()
==
at
::
ScalarType
::
BFloat16
||
q_nope
.
dtype
()
==
at
::
ScalarType
::
Float8_e4m3fn
,
"q_nope must be a half, bfloat16, or float8_e4m3fn tensor"
);
TORCH_CHECK
(
kv_c_and_k_pe_cache
.
dtype
()
==
q_nope
.
dtype
()
&&
q_nope
.
dtype
()
==
q_pe
.
dtype
(),
"kv_c_and_k_pe_cache, q_nope, and q_pe must be the same type"
);
TORCH_CHECK
(
seq_lens
.
dtype
()
==
torch
::
kInt32
,
"seq_lens must be a 32-bit integer tensor"
);
TORCH_CHECK
(
page_table
.
dtype
()
==
torch
::
kInt32
,
"page_table must be a 32-bit integer tensor"
);
auto
in_dtype
=
q_nope
.
dtype
();
at
::
cuda
::
CUDAGuard
device_guard
{(
char
)
q_nope
.
get_device
()};
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
q_nope
.
get_device
());
if
(
in_dtype
==
at
::
ScalarType
::
Half
)
{
runMla
<
cutlass
::
half_t
>
(
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
scale
,
stream
);
}
else
if
(
in_dtype
==
at
::
ScalarType
::
BFloat16
)
{
runMla
<
cutlass
::
bfloat16_t
>
(
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
scale
,
stream
);
}
else
if
(
in_dtype
==
at
::
ScalarType
::
Float8_e4m3fn
)
{
runMla
<
cutlass
::
float_e4m3_t
>
(
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
scale
,
stream
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported input data type of MLA"
);
}
}
csrc/ops.h
View file @
ed7a29d9
...
@@ -128,6 +128,12 @@ void advance_step_flashinfer(
...
@@ -128,6 +128,12 @@ void advance_step_flashinfer(
torch
::
Tensor
&
paged_kv_indices
,
torch
::
Tensor
&
paged_kv_indptr
,
torch
::
Tensor
&
paged_kv_indices
,
torch
::
Tensor
&
paged_kv_indptr
,
torch
::
Tensor
&
paged_kv_last_page_len
,
torch
::
Tensor
&
block_table_bounds
);
torch
::
Tensor
&
paged_kv_last_page_len
,
torch
::
Tensor
&
block_table_bounds
);
void
cutlass_mla_decode
(
torch
::
Tensor
const
&
out
,
torch
::
Tensor
const
&
q_nope
,
torch
::
Tensor
const
&
q_pe
,
torch
::
Tensor
const
&
kv_c_and_k_pe_cache
,
torch
::
Tensor
const
&
seq_lens
,
torch
::
Tensor
const
&
page_table
,
double
scale
);
torch
::
Tensor
get_cuda_view_from_cpu_tensor
(
torch
::
Tensor
&
cpu_tensor
);
torch
::
Tensor
get_cuda_view_from_cpu_tensor
(
torch
::
Tensor
&
cpu_tensor
);
#ifndef USE_ROCM
#ifndef USE_ROCM
...
...
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
View file @
ed7a29d9
...
@@ -134,7 +134,7 @@ typename T::Gemm::Arguments args_from_options(
...
@@ -134,7 +134,7 @@ typename T::Gemm::Arguments args_from_options(
using
StrideB
=
typename
T
::
StrideB
;
using
StrideB
=
typename
T
::
StrideB
;
using
StrideD
=
typename
T
::
StrideD
;
using
StrideD
=
typename
T
::
StrideD
;
using
Sm100BlkScaledConfig
=
using
Sm100BlkScaledConfig
=
typename
T
::
Gemm
::
GemmKernel
::
CollectiveMainloop
::
Sm1
00
BlkScaledConfig
;
typename
T
::
Gemm
::
GemmKernel
::
CollectiveMainloop
::
Sm1
xx
BlkScaledConfig
;
int
m
=
static_cast
<
int
>
(
M
);
int
m
=
static_cast
<
int
>
(
M
);
int
n
=
static_cast
<
int
>
(
N
);
int
n
=
static_cast
<
int
>
(
N
);
...
...
csrc/torch_bindings.cpp
View file @
ed7a29d9
...
@@ -130,6 +130,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -130,6 +130,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
") -> ()"
);
") -> ()"
);
ops
.
impl
(
"advance_step_flashinfer"
,
torch
::
kCUDA
,
&
advance_step_flashinfer
);
ops
.
impl
(
"advance_step_flashinfer"
,
torch
::
kCUDA
,
&
advance_step_flashinfer
);
// Compute MLA decode using cutlass.
ops
.
def
(
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
" Tensor page_table, float scale) -> ()"
);
ops
.
impl
(
"cutlass_mla_decode"
,
torch
::
kCUDA
,
&
cutlass_mla_decode
);
// Layernorm
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops
.
def
(
ops
.
def
(
...
...
tests/kernels/test_cutlass_mla_decode.py
0 → 100644
View file @
ed7a29d9
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
torch
import
torch.nn.functional
as
F
from
torch
import
Tensor
import
vllm._custom_ops
as
ops
from
vllm.platforms
import
current_platform
if
not
current_platform
.
has_device_capability
(
100
):
pytest
.
skip
(
reason
=
"Cutlass MLA Requires compute capability of 10 or above."
,
allow_module_level
=
True
)
def
ref_mla
(
out
:
Tensor
,
# (bs, num_heads, v_head_dim)
query
:
Tensor
,
# (bs, num_heads, head_dim)
kv_cache
:
Tensor
,
# (num_blocks, block_size, head_dim)
scale
:
float
,
block_tables
:
Tensor
,
# (bs, max_num_blocks)
seq_lens
:
Tensor
,
# (bs,)
):
bs
,
num_heads
,
v_head_dim
=
out
.
shape
head_dim
=
query
.
shape
[
2
]
for
i
in
range
(
bs
):
# gather and flatten KV-cache
kv
=
kv_cache
[
block_tables
[
i
]]
# (max_num_blocks, block_size, head_dim)
kv
=
kv
.
view
(
1
,
-
1
,
head_dim
)[:,
:
seq_lens
[
i
]]
# (1, seq_len, head_dim)
v
=
kv
[:,
:,
:
v_head_dim
]
q
=
query
[
i
].
view
(
num_heads
,
1
,
head_dim
)
o
=
F
.
scaled_dot_product_attention
(
q
,
kv
,
v
,
scale
=
scale
,
enable_gqa
=
True
)
out
[
i
]
=
o
.
view
(
num_heads
,
v_head_dim
)
return
out
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"mean_seq_len"
,
[
128
,
1024
,
4096
])
@
pytest
.
mark
.
parametrize
(
"bs"
,
[
1
,
2
,
4
])
@
pytest
.
mark
.
parametrize
(
"varlen"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
,
64
,
128
])
def
test_cutlass_mla_decode
(
dtype
:
torch
.
dtype
,
mean_seq_len
:
int
,
bs
:
int
,
varlen
:
bool
,
block_size
:
int
):
torch
.
set_default_dtype
(
dtype
)
torch
.
set_default_device
(
'cuda'
)
torch
.
manual_seed
(
42
)
d
=
576
h_q
=
128
dv
=
512
q_nope_dim
=
128
q_pe_dim
=
64
scale
=
(
q_nope_dim
+
q_pe_dim
)
**
(
-
0.5
)
if
varlen
:
seq_lens
=
torch
.
empty
(
bs
).
normal_
(
mean_seq_len
,
mean_seq_len
/
2
)
seq_lens
=
seq_lens
.
clip
(
2
).
to
(
torch
.
int32
)
else
:
seq_lens
=
torch
.
full
((
bs
,
),
mean_seq_len
,
dtype
=
torch
.
int32
)
max_seq_len
=
seq_lens
.
max
().
item
()
block_num
=
(
max_seq_len
+
block_size
-
1
)
//
block_size
# Pad block_num so that small blocks can be packed into full 128-sized
# CUTLASS tiles. One 128-wide tile can hold (128 // block_size) small
# blocks.
pack_factor
=
128
//
block_size
block_num
=
((
block_num
+
pack_factor
-
1
)
//
pack_factor
)
*
pack_factor
q
=
torch
.
randn
(
bs
,
h_q
,
d
)
block_table
=
torch
.
randint
(
0
,
bs
*
block_num
,
(
bs
,
block_num
),
dtype
=
torch
.
int32
)
kv_cache
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
d
)
out_ref
=
q
.
new_zeros
(
bs
,
h_q
,
dv
)
ref_mla
(
out_ref
,
q
,
kv_cache
,
scale
,
block_table
,
seq_lens
)
out_ans
=
torch
.
zeros_like
(
out_ref
)
q_nope
=
q
[:,
:,
:
dv
].
clone
()
q_pe
=
q
[:,
:,
dv
:].
clone
()
ops
.
cutlass_mla_decode
(
out_ans
,
q_nope
,
q_pe
,
kv_cache
,
seq_lens
,
block_table
,
scale
)
torch
.
testing
.
assert_close
(
out_ans
,
out_ref
,
atol
=
1e-2
,
rtol
=
1e-2
)
vllm/_custom_ops.py
View file @
ed7a29d9
...
@@ -1525,3 +1525,12 @@ def flash_mla_with_kvcache(
...
@@ -1525,3 +1525,12 @@ def flash_mla_with_kvcache(
num_splits
,
num_splits
,
)
)
return
out
,
softmax_lse
return
out
,
softmax_lse
def
cutlass_mla_decode
(
out
:
torch
.
Tensor
,
q_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
page_table
:
torch
.
Tensor
,
scale
:
float
)
->
torch
.
Tensor
:
torch
.
ops
.
_C
.
cutlass_mla_decode
(
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
scale
)
return
out
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment