Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e757a629
Unverified
Commit
e757a629
authored
Sep 15, 2025
by
Wentao Ye
Committed by
GitHub
Sep 15, 2025
Browse files
[Bug] Fix Cutlass Scaled MM Compilation Error (#24887)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent
aae725af
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
53 additions
and
41 deletions
+53
-41
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
...tlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
+23
-19
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
...tlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
+15
-11
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
...utlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
+15
-11
No files found.
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
View file @
e757a629
...
@@ -146,6 +146,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
...
@@ -146,6 +146,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
using
ElementD
=
typename
Gemm
::
ElementD
;
using
ElementBlockScale
=
typename
Gemm
::
ElementBlockScale
;
int32_t
m
=
a
.
size
(
0
),
n
=
b
.
size
(
1
),
k
=
a
.
size
(
1
);
int32_t
m
=
a
.
size
(
0
),
n
=
b
.
size
(
1
),
k
=
a
.
size
(
1
);
...
@@ -166,26 +167,29 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
...
@@ -166,26 +167,29 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
n
,
m
,
k
,
1
))
:
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
n
,
m
,
k
,
1
))
:
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
m
,
n
,
k
,
1
));
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
m
,
n
,
k
,
1
));
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
a_ptr
=
static_cast
<
ElementAB
const
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
const
*>
(
b
.
data_ptr
());
auto
a_scales_ptr
=
static_cast
<
float
*>
(
a_scales
.
data_ptr
());
auto
a_scales_ptr
=
static_cast
<
ElementBlockScale
const
*>
(
a_scales
.
data_ptr
());
auto
b_scales_ptr
=
static_cast
<
float
*>
(
b_scales
.
data_ptr
());
auto
b_scales_ptr
=
static_cast
<
ElementBlockScale
const
*>
(
b_scales
.
data_ptr
());
auto
mainloop_args
=
[
&
](){
typename
GemmKernel
::
MainloopArguments
mainloop_args
{};
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
mainloop_args
.
layout_SFA
=
layout_SFA
;
if
(
swap_ab
)
{
mainloop_args
.
layout_SFB
=
layout_SFB
;
return
typename
GemmKernel
::
MainloopArguments
{
if
(
swap_ab
)
{
b_ptr
,
b_stride
,
a_ptr
,
a_stride
,
mainloop_args
.
ptr_A
=
b_ptr
;
b_scales_ptr
,
layout_SFA
,
a_scales_ptr
,
layout_SFB
mainloop_args
.
dA
=
b_stride
;
};
mainloop_args
.
ptr_B
=
a_ptr
;
}
mainloop_args
.
dB
=
a_stride
;
else
{
mainloop_args
.
ptr_SFA
=
b_scales_ptr
;
return
typename
GemmKernel
::
MainloopArguments
{
mainloop_args
.
ptr_SFB
=
a_scales_ptr
;
a_ptr
,
a_stride
,
b_ptr
,
b_stride
,
}
else
{
a_scales_ptr
,
layout_SFA
,
b_scales_ptr
,
layout_SFB
mainloop_args
.
ptr_A
=
a_ptr
;
};
mainloop_args
.
dA
=
a_stride
;
}
mainloop_args
.
ptr_B
=
b_ptr
;
}();
mainloop_args
.
dB
=
b_stride
;
mainloop_args
.
ptr_SFA
=
a_scales_ptr
;
mainloop_args
.
ptr_SFB
=
b_scales_ptr
;
}
auto
prob_shape
=
swap_ab
?
cute
::
make_shape
(
n
,
m
,
k
,
1
)
:
cute
::
make_shape
(
m
,
n
,
k
,
1
);
auto
prob_shape
=
swap_ab
?
cute
::
make_shape
(
n
,
m
,
k
,
1
)
:
cute
::
make_shape
(
m
,
n
,
k
,
1
);
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
...
...
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
View file @
e757a629
...
@@ -125,6 +125,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
...
@@ -125,6 +125,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
using
ElementD
=
typename
Gemm
::
ElementD
;
using
ElementBlockScale
=
typename
Gemm
::
ElementBlockScale
;
int32_t
m
=
a
.
size
(
0
),
n
=
b
.
size
(
1
),
k
=
a
.
size
(
1
);
int32_t
m
=
a
.
size
(
0
),
n
=
b
.
size
(
1
),
k
=
a
.
size
(
1
);
...
@@ -143,17 +144,20 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
...
@@ -143,17 +144,20 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
LayoutSFB
layout_SFB
=
LayoutSFB
layout_SFB
=
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
m
,
n
,
k
,
1
));
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
m
,
n
,
k
,
1
));
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
a_ptr
=
static_cast
<
ElementAB
const
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
const
*>
(
b
.
data_ptr
());
auto
a_scales_ptr
=
static_cast
<
float
*>
(
a_scales
.
data_ptr
());
auto
a_scales_ptr
=
static_cast
<
ElementBlockScale
const
*>
(
a_scales
.
data_ptr
());
auto
b_scales_ptr
=
static_cast
<
float
*>
(
b_scales
.
data_ptr
());
auto
b_scales_ptr
=
static_cast
<
ElementBlockScale
const
*>
(
b_scales
.
data_ptr
());
auto
mainloop_args
=
[
&
](){
typename
GemmKernel
::
MainloopArguments
mainloop_args
{};
return
typename
GemmKernel
::
MainloopArguments
{
mainloop_args
.
ptr_A
=
a_ptr
;
a_ptr
,
a_stride
,
b_ptr
,
b_stride
,
mainloop_args
.
dA
=
a_stride
;
a_scales_ptr
,
layout_SFA
,
b_scales_ptr
,
layout_SFB
mainloop_args
.
ptr_B
=
b_ptr
;
};
mainloop_args
.
dB
=
b_stride
;
}();
mainloop_args
.
ptr_SFA
=
a_scales_ptr
;
mainloop_args
.
layout_SFA
=
layout_SFA
;
mainloop_args
.
ptr_SFB
=
b_scales_ptr
;
mainloop_args
.
layout_SFB
=
layout_SFB
;
auto
prob_shape
=
cute
::
make_shape
(
m
,
n
,
k
,
1
);
auto
prob_shape
=
cute
::
make_shape
(
m
,
n
,
k
,
1
);
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
...
...
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
View file @
e757a629
...
@@ -115,6 +115,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
...
@@ -115,6 +115,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
using
ElementD
=
typename
Gemm
::
ElementD
;
using
ElementBlockScale
=
typename
Gemm
::
ElementBlockScale
;
int32_t
m
=
a
.
size
(
0
),
n
=
b
.
size
(
1
),
k
=
a
.
size
(
1
);
int32_t
m
=
a
.
size
(
0
),
n
=
b
.
size
(
1
),
k
=
a
.
size
(
1
);
...
@@ -135,17 +136,20 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
...
@@ -135,17 +136,20 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
LayoutSFB
layout_SFB
=
LayoutSFB
layout_SFB
=
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
m
,
n
,
k
,
1
));
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
m
,
n
,
k
,
1
));
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
a_ptr
=
static_cast
<
ElementAB
const
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
const
*>
(
b
.
data_ptr
());
auto
a_scales_ptr
=
static_cast
<
float
*>
(
a_scales
.
data_ptr
());
auto
a_scales_ptr
=
static_cast
<
ElementBlockScale
const
*>
(
a_scales
.
data_ptr
());
auto
b_scales_ptr
=
static_cast
<
float
*>
(
b_scales
.
data_ptr
());
auto
b_scales_ptr
=
static_cast
<
ElementBlockScale
const
*>
(
b_scales
.
data_ptr
());
auto
mainloop_args
=
[
&
](){
typename
GemmKernel
::
MainloopArguments
mainloop_args
{};
return
typename
GemmKernel
::
MainloopArguments
{
mainloop_args
.
ptr_A
=
a_ptr
;
a_ptr
,
a_stride
,
b_ptr
,
b_stride
,
mainloop_args
.
dA
=
a_stride
;
a_scales_ptr
,
layout_SFA
,
b_scales_ptr
,
layout_SFB
mainloop_args
.
ptr_B
=
b_ptr
;
};
mainloop_args
.
dB
=
b_stride
;
}();
mainloop_args
.
ptr_SFA
=
a_scales_ptr
;
mainloop_args
.
layout_SFA
=
layout_SFA
;
mainloop_args
.
ptr_SFB
=
b_scales_ptr
;
mainloop_args
.
layout_SFB
=
layout_SFB
;
auto
prob_shape
=
cute
::
make_shape
(
m
,
n
,
k
,
1
);
auto
prob_shape
=
cute
::
make_shape
(
m
,
n
,
k
,
1
);
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment