Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2d465ec7
Commit
2d465ec7
authored
Apr 29, 2025
by
zhuwenwen
Browse files
skip cutlass_mla_decode
parent
081057de
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
17 deletions
+17
-17
csrc/ops.h
csrc/ops.h
+5
-5
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+5
-5
vllm/_custom_ops.py
vllm/_custom_ops.py
+7
-7
No files found.
csrc/ops.h
View file @
2d465ec7
...
@@ -130,11 +130,11 @@ void advance_step_flashinfer(
...
@@ -130,11 +130,11 @@ void advance_step_flashinfer(
torch
::
Tensor
&
paged_kv_indices
,
torch
::
Tensor
&
paged_kv_indptr
,
torch
::
Tensor
&
paged_kv_indices
,
torch
::
Tensor
&
paged_kv_indptr
,
torch
::
Tensor
&
paged_kv_last_page_len
,
torch
::
Tensor
&
block_table_bounds
);
torch
::
Tensor
&
paged_kv_last_page_len
,
torch
::
Tensor
&
block_table_bounds
);
void
cutlass_mla_decode
(
torch
::
Tensor
const
&
out
,
torch
::
Tensor
const
&
q_nope
,
//
void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
torch
::
Tensor
const
&
q_pe
,
//
torch::Tensor const& q_pe,
torch
::
Tensor
const
&
kv_c_and_k_pe_cache
,
//
torch::Tensor const& kv_c_and_k_pe_cache,
torch
::
Tensor
const
&
seq_lens
,
//
torch::Tensor const& seq_lens,
torch
::
Tensor
const
&
page_table
,
double
scale
);
//
torch::Tensor const& page_table, double scale);
torch
::
Tensor
get_cuda_view_from_cpu_tensor
(
torch
::
Tensor
&
cpu_tensor
);
torch
::
Tensor
get_cuda_view_from_cpu_tensor
(
torch
::
Tensor
&
cpu_tensor
);
...
...
csrc/torch_bindings.cpp
View file @
2d465ec7
...
@@ -131,11 +131,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -131,11 +131,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
impl
(
"advance_step_flashinfer"
,
torch
::
kCUDA
,
&
advance_step_flashinfer
);
ops
.
impl
(
"advance_step_flashinfer"
,
torch
::
kCUDA
,
&
advance_step_flashinfer
);
// Compute MLA decode using cutlass.
// Compute MLA decode using cutlass.
ops
.
def
(
//
ops.def(
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
//
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
//
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
" Tensor page_table, float scale) -> ()"
);
//
" Tensor page_table, float scale) -> ()");
ops
.
impl
(
"cutlass_mla_decode"
,
torch
::
kCUDA
,
&
cutlass_mla_decode
);
//
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
// Layernorm
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
// Apply Root Mean Square (RMS) Normalization to the input tensor.
...
...
vllm/_custom_ops.py
View file @
2d465ec7
...
@@ -1533,10 +1533,10 @@ def flash_mla_with_kvcache(
...
@@ -1533,10 +1533,10 @@ def flash_mla_with_kvcache(
return
out
,
softmax_lse
return
out
,
softmax_lse
def
cutlass_mla_decode
(
out
:
torch
.
Tensor
,
q_nope
:
torch
.
Tensor
,
#
def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
#
q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
seq_lens
:
torch
.
Tensor
,
page_table
:
torch
.
Tensor
,
#
seq_lens: torch.Tensor, page_table: torch.Tensor,
scale
:
float
)
->
torch
.
Tensor
:
#
scale: float) -> torch.Tensor:
torch
.
ops
.
_C
.
cutlass_mla_decode
(
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
#
torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
seq_lens
,
page_table
,
scale
)
#
seq_lens, page_table, scale)
return
out
#
return out
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment