Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
970d6d07
Unverified
Commit
970d6d07
authored
Dec 30, 2024
by
Tyler Michael Smith
Committed by
GitHub
Dec 30, 2024
Browse files
[Build][Kernel] Update CUTLASS to v3.6.0 (#11607)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
parent
628ec6c1
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing 6 changed files with 25 additions and 31 deletions
+25
-31
CMakeLists.txt
CMakeLists.txt
+2
-2
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
+9
-9
csrc/quantization/machete/generate.py
csrc/quantization/machete/generate.py
+4
-4
csrc/quantization/machete/machete_collective_builder.cuh
csrc/quantization/machete/machete_collective_builder.cuh
+4
-6
csrc/quantization/machete/machete_mainloop.cuh
csrc/quantization/machete/machete_mainloop.cuh
+4
-7
csrc/quantization/machete/machete_prepacked_layout.cuh
csrc/quantization/machete/machete_prepacked_layout.cuh
+2
-3
No files found.
CMakeLists.txt
View file @
970d6d07
...
@@ -223,13 +223,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -223,13 +223,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
FetchContent_Declare
(
FetchContent_Declare
(
cutlass
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227
GIT_TAG v3.6.0
GIT_PROGRESS TRUE
GIT_PROGRESS TRUE
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
GIT_SHALLOW FALSE
GIT_SHALLOW TRUE
)
)
endif
()
endif
()
FetchContent_MakeAvailable
(
cutlass
)
FetchContent_MakeAvailable
(
cutlass
)
...
...
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
View file @
970d6d07
...
@@ -14,9 +14,9 @@ class VLLMDataType(enum.Enum):
...
@@ -14,9 +14,9 @@ class VLLMDataType(enum.Enum):
class
MixedInputKernelScheduleType
(
enum
.
Enum
):
class
MixedInputKernelScheduleType
(
enum
.
Enum
):
TmaWarpSpecialized
MixedInput
=
enum_auto
()
TmaWarpSpecialized
=
enum_auto
()
TmaWarpSpecializedPingpong
MixedInput
=
enum_auto
()
TmaWarpSpecializedPingpong
=
enum_auto
()
TmaWarpSpecializedCooperative
MixedInput
=
enum_auto
()
TmaWarpSpecializedCooperative
=
enum_auto
()
VLLMDataTypeNames
:
Dict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
VLLMDataTypeNames
:
Dict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
...
@@ -68,11 +68,11 @@ VLLMKernelScheduleTag: Dict[Union[
...
@@ -68,11 +68,11 @@ VLLMKernelScheduleTag: Dict[Union[
MixedInputKernelScheduleType
,
KernelScheduleType
],
str
]
=
{
MixedInputKernelScheduleType
,
KernelScheduleType
],
str
]
=
{
**
KernelScheduleTag
,
# type: ignore
**
KernelScheduleTag
,
# type: ignore
**
{
**
{
MixedInputKernelScheduleType
.
TmaWarpSpecialized
MixedInput
:
MixedInputKernelScheduleType
.
TmaWarpSpecialized
:
"cutlass::gemm::KernelTmaWarpSpecialized
MixedInput
"
,
"cutlass::gemm::KernelTmaWarpSpecialized"
,
MixedInputKernelScheduleType
.
TmaWarpSpecializedPingpong
MixedInput
:
MixedInputKernelScheduleType
.
TmaWarpSpecializedPingpong
:
"cutlass::gemm::KernelTmaWarpSpecializedPingpong
MixedInput
"
,
"cutlass::gemm::KernelTmaWarpSpecializedPingpong"
,
MixedInputKernelScheduleType
.
TmaWarpSpecializedCooperative
MixedInput
:
MixedInputKernelScheduleType
.
TmaWarpSpecializedCooperative
:
"cutlass::gemm::KernelTmaWarpSpecializedCooperative
MixedInput
"
,
"cutlass::gemm::KernelTmaWarpSpecializedCooperative"
,
}
}
}
}
csrc/quantization/machete/generate.py
View file @
970d6d07
...
@@ -189,7 +189,7 @@ using Kernel_{{type_sig}} = MacheteKernelTemplate<
...
@@ -189,7 +189,7 @@ using Kernel_{{type_sig}} = MacheteKernelTemplate<
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
Sch>;
Sch>;
{% for sch in schs %}
{% for sch in schs %}
...
@@ -223,7 +223,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
...
@@ -223,7 +223,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
{{DataTypeTag[t.convert]}}, // ElementConvert
{{DataTypeTag[t.convert]}}, // ElementConvert
{{DataTypeTag[t.accumulator]}}, // Accumulator
{{DataTypeTag[t.accumulator]}}, // Accumulator
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>
cutlass::gemm::KernelTmaWarpSpecializedCooperative>
>(args.B);
>(args.B);
}
}
{%- endfor %}
{%- endfor %}
...
@@ -239,7 +239,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
...
@@ -239,7 +239,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
}; // namespace machete
}; // namespace machete
"""
"""
TmaMI
=
MixedInputKernelScheduleType
.
TmaWarpSpecializedCooperative
MixedInput
TmaMI
=
MixedInputKernelScheduleType
.
TmaWarpSpecializedCooperative
TmaCoop
=
EpilogueScheduleType
.
TmaWarpSpecializedCooperative
TmaCoop
=
EpilogueScheduleType
.
TmaWarpSpecializedCooperative
...
@@ -300,7 +300,7 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
...
@@ -300,7 +300,7 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
# mostly unique shorter sch_sig
# mostly unique shorter sch_sig
def
generate_terse_sch_sig
(
schedule_config
:
ScheduleConfig
)
->
str
:
def
generate_terse_sch_sig
(
schedule_config
:
ScheduleConfig
)
->
str
:
kernel_terse_names_replace
=
{
kernel_terse_names_replace
=
{
"KernelTmaWarpSpecializedCooperative
MixedInput_
"
:
"TmaMI_"
,
"KernelTmaWarpSpecializedCooperative"
:
"TmaMI_"
,
"TmaWarpSpecializedCooperative_"
:
"TmaCoop_"
,
"TmaWarpSpecializedCooperative_"
:
"TmaCoop_"
,
"StreamKScheduler"
:
"streamK"
,
"StreamKScheduler"
:
"streamK"
,
}
}
...
...
csrc/quantization/machete/machete_collective_builder.cuh
View file @
970d6d07
...
@@ -18,16 +18,14 @@ struct VLLMCollectiveBuilder<
...
@@ -18,16 +18,14 @@ struct VLLMCollectiveBuilder<
ElementAccumulator
,
TileShape_MNK
,
ClusterShape_MNK
,
StageCountType
,
ElementAccumulator
,
TileShape_MNK
,
ClusterShape_MNK
,
StageCountType
,
KernelScheduleType
,
KernelScheduleType
,
cute
::
enable_if_t
<
(
cute
::
enable_if_t
<
(
cute
::
is_same_v
<
KernelScheduleType
,
KernelTmaWarpSpecialized
>
||
cute
::
is_same_v
<
KernelScheduleType
,
KernelTmaWarpSpecializedPingpong
>
||
cute
::
is_same_v
<
KernelScheduleType
,
cute
::
is_same_v
<
KernelScheduleType
,
KernelTmaWarpSpecializedMixedInput
>
||
KernelTmaWarpSpecializedCooperative
>
)
>>
{
cute
::
is_same_v
<
KernelScheduleType
,
KernelTmaWarpSpecializedPingpongMixedInput
>
||
cute
::
is_same_v
<
KernelScheduleType
,
KernelTmaWarpSpecializedCooperativeMixedInput
>
)
>>
{
using
CollectiveOp
=
machete
::
MacheteCollectiveMma
<
using
CollectiveOp
=
machete
::
MacheteCollectiveMma
<
ElementPairA_
,
GmemLayoutA_
,
AlignmentA
,
ElementPairB_
,
GmemLayoutB_
,
ElementPairA_
,
GmemLayoutA_
,
AlignmentA
,
ElementPairB_
,
GmemLayoutB_
,
AlignmentB
,
ElementAccumulator
,
TileShape_MNK
,
ClusterShape_MNK
,
AlignmentB
,
ElementAccumulator
,
TileShape_MNK
,
ClusterShape_MNK
,
StageCountType
,
KernelScheduleType
>
;
StageCountType
,
KernelScheduleType
>
;
};
};
};
// namespace cutlass::gemm::collective
};
// namespace cutlass::gemm::collective
\ No newline at end of file
csrc/quantization/machete/machete_mainloop.cuh
View file @
970d6d07
...
@@ -66,13 +66,11 @@ struct MacheteCollectiveMma {
...
@@ -66,13 +66,11 @@ struct MacheteCollectiveMma {
using
Schedule
=
KernelScheduleType
;
using
Schedule
=
KernelScheduleType
;
static_assert
(
static_assert
(
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecialized
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecialized
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedMixedInput
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecialized
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedPingpong
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedPingpong
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedPingpong
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedPingpongMixedInput
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedCooperative
>
||
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedCooperative
>
||
cute
::
is_same_v
<
Schedule
,
cute
::
is_same_v
<
Schedule
,
KernelTmaWarpSpecializedCooperative
>
,
KernelTmaWarpSpecializedCooperativeMixedInput
>
,
"KernelSchedule must be one of the warp specialized policies"
);
"KernelSchedule must be one of the warp specialized policies"
);
public:
public:
...
@@ -113,8 +111,7 @@ struct MacheteCollectiveMma {
...
@@ -113,8 +111,7 @@ struct MacheteCollectiveMma {
// For coop schedules we have two warp groups cooperatively issuing wgmma
// For coop schedules we have two warp groups cooperatively issuing wgmma
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
using
AtomLayoutMNK
=
cute
::
conditional_t
<
using
AtomLayoutMNK
=
cute
::
conditional_t
<
cute
::
is_same_v
<
KernelScheduleType
,
cute
::
is_same_v
<
KernelScheduleType
,
KernelTmaWarpSpecializedCooperative
>
,
KernelTmaWarpSpecializedCooperativeMixedInput
>
,
Layout
<
Shape
<
_2
,
_1
,
_1
>>
,
Layout
<
Shape
<
_1
,
_1
,
_1
>>>
;
Layout
<
Shape
<
_2
,
_1
,
_1
>>
,
Layout
<
Shape
<
_1
,
_1
,
_1
>>>
;
using
TiledMma
=
decltype
(
cute
::
make_tiled_mma
(
using
TiledMma
=
decltype
(
cute
::
make_tiled_mma
(
...
...
csrc/quantization/machete/machete_prepacked_layout.cuh
View file @
970d6d07
...
@@ -98,8 +98,7 @@ struct PrepackedLayoutBTemplate {
...
@@ -98,8 +98,7 @@ struct PrepackedLayoutBTemplate {
// For coop schedules we have two warp groups cooperatively issuing wgmma
// For coop schedules we have two warp groups cooperatively issuing wgmma
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
using
AtomLayoutMNK
=
cute
::
conditional_t
<
using
AtomLayoutMNK
=
cute
::
conditional_t
<
cute
::
is_same_v
<
KernelSchedule
,
cute
::
is_same_v
<
KernelSchedule
,
KernelTmaWarpSpecializedCooperative
>
,
KernelTmaWarpSpecializedCooperativeMixedInput
>
,
Layout
<
Shape
<
_2
,
_1
,
_1
>>
,
Layout
<
Shape
<
_1
,
_1
,
_1
>>>
;
Layout
<
Shape
<
_2
,
_1
,
_1
>>
,
Layout
<
Shape
<
_1
,
_1
,
_1
>>>
;
using
TiledMma
=
decltype
(
cute
::
make_tiled_mma
(
using
TiledMma
=
decltype
(
cute
::
make_tiled_mma
(
...
@@ -247,4 +246,4 @@ struct PrepackedLayoutBTemplate {
...
@@ -247,4 +246,4 @@ struct PrepackedLayoutBTemplate {
}
}
};
};
};
// namespace machete
};
// namespace machete
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment