Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
839b27c6
Unverified
Commit
839b27c6
authored
Feb 21, 2025
by
leoneo
Committed by
GitHub
Feb 20, 2025
Browse files
[Kernel]Add streamK for block-quantized CUTLASS kernels (#12978)
parent
34ad27fe
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
12 deletions
+44
-12
csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
+11
-5
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
...utlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
+33
-7
No files found.
csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
View file @
839b27c6
...
@@ -30,12 +30,18 @@ static inline cute::Shape<int, int, int, int> get_problem_shape(
...
@@ -30,12 +30,18 @@ static inline cute::Shape<int, int, int, int> get_problem_shape(
}
}
template
<
typename
GemmKernel
>
template
<
typename
GemmKernel
>
void
cutlass_gemm_caller
(
torch
::
Device
device
,
void
cutlass_gemm_caller
(
cute
::
Shape
<
int
,
int
,
int
,
int
>
prob_shape
,
torch
::
Device
device
,
cute
::
Shape
<
int
,
int
,
int
,
int
>
prob_shape
,
typename
GemmKernel
::
MainloopArguments
mainloop_args
,
typename
GemmKernel
::
MainloopArguments
mainloop_args
,
typename
GemmKernel
::
EpilogueArguments
epilogue_args
)
{
typename
GemmKernel
::
EpilogueArguments
epilogue_args
,
typename
GemmKernel
::
TileSchedulerArguments
scheduler
=
{})
{
cutlass
::
KernelHardwareInfo
hw_info
;
typename
GemmKernel
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemm
,
typename
GemmKernel
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemm
,
prob_shape
,
mainloop_args
,
epilogue_args
};
prob_shape
,
mainloop_args
,
epilogue_args
,
hw_info
,
scheduler
};
// Launch the CUTLASS GEMM kernel.
// Launch the CUTLASS GEMM kernel.
using
GemmOp
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
using
GemmOp
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
...
...
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
View file @
839b27c6
...
@@ -22,8 +22,9 @@ namespace vllm {
...
@@ -22,8 +22,9 @@ namespace vllm {
using
namespace
cute
;
using
namespace
cute
;
template
<
typename
OutType
,
int
GroupSizeM_
,
int
GroupSizeN_
,
int
GroupSizeK_
,
template
<
typename
SchedulerType
,
typename
OutType
,
int
GroupSizeM_
,
int
TileSizeM_
=
128
,
class
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
>
int
GroupSizeN_
,
int
GroupSizeK_
,
int
TileSizeM_
=
128
,
class
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
>
struct
cutlass_3x_gemm_fp8_blockwise
{
struct
cutlass_3x_gemm_fp8_blockwise
{
using
GroupSizeM
=
Int
<
GroupSizeM_
>
;
using
GroupSizeM
=
Int
<
GroupSizeM_
>
;
using
GroupSizeN
=
Int
<
GroupSizeN_
>
;
using
GroupSizeN
=
Int
<
GroupSizeN_
>
;
...
@@ -84,7 +85,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
...
@@ -84,7 +85,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
using
KernelType
=
enable_sm90_or_later
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
using
KernelType
=
enable_sm90_or_later
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
,
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
,
cutlass
::
gemm
::
Persistent
Scheduler
>>
;
Scheduler
Type
>>
;
struct
GemmKernel
:
public
KernelType
{};
struct
GemmKernel
:
public
KernelType
{};
...
@@ -150,8 +151,24 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
...
@@ -150,8 +151,24 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
{},
c_ptr
,
c_stride
,
c_ptr
,
c_stride
};
{},
c_ptr
,
c_stride
,
c_ptr
,
c_stride
};
typename
GemmKernel
::
TileSchedulerArguments
scheduler
;
static
constexpr
bool
UsesStreamKScheduler
=
cute
::
is_same_v
<
typename
GemmKernel
::
TileSchedulerTag
,
cutlass
::
gemm
::
StreamKScheduler
>
;
if
constexpr
(
UsesStreamKScheduler
)
{
using
DecompositionMode
=
typename
cutlass
::
gemm
::
kernel
::
detail
::
PersistentTileSchedulerSm90StreamKParams
::
DecompositionMode
;
using
ReductionMode
=
typename
cutlass
::
gemm
::
kernel
::
detail
::
PersistentTileSchedulerSm90StreamKParams
::
ReductionMode
;
scheduler
.
decomposition_mode
=
DecompositionMode
::
StreamK
;
scheduler
.
reduction_mode
=
ReductionMode
::
Nondeterministic
;
}
c3x
::
cutlass_gemm_caller
<
GemmKernel
>
(
a
.
device
(),
prob_shape
,
mainloop_args
,
c3x
::
cutlass_gemm_caller
<
GemmKernel
>
(
a
.
device
(),
prob_shape
,
mainloop_args
,
epilogue_args
);
epilogue_args
,
scheduler
);
}
}
template
<
typename
OutType
>
template
<
typename
OutType
>
...
@@ -160,9 +177,18 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
...
@@ -160,9 +177,18 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
torch
::
Tensor
const
&
b_scales
)
{
cutlass_gemm_caller_blockwise
<
auto
k
=
a
.
size
(
1
);
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
1
,
128
,
128
>>
(
out
,
a
,
b
,
a_scales
,
auto
n
=
b
.
size
(
1
);
b_scales
);
if
(
k
>
3
*
n
)
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
cutlass
::
gemm
::
StreamKScheduler
,
OutType
,
1
,
128
,
128
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
cutlass
::
gemm
::
PersistentScheduler
,
OutType
,
1
,
128
,
128
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
}
// namespace vllm
}
// namespace vllm
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment