Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f18cfec4
Unverified
Commit
f18cfec4
authored
Feb 14, 2025
by
Haocong WANG
Committed by
GitHub
Feb 14, 2025
Browse files
Merge branch 'develop' into update_cka8w8_uc
parents
4599ee00
0e5e29c4
Changes
163
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
967 additions
and
106 deletions
+967
-106
include/ck_tile/host/reference/reference_moe_sorting.hpp
include/ck_tile/host/reference/reference_moe_sorting.hpp
+24
-2
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
+1
-0
include/ck_tile/ops/batched_transpose.hpp
include/ck_tile/ops/batched_transpose.hpp
+1
-0
include/ck_tile/ops/common.hpp
include/ck_tile/ops/common.hpp
+1
-0
include/ck_tile/ops/common/utils.hpp
include/ck_tile/ops/common/utils.hpp
+34
-0
include/ck_tile/ops/elementwise.hpp
include/ck_tile/ops/elementwise.hpp
+1
-0
include/ck_tile/ops/epilogue.hpp
include/ck_tile/ops/epilogue.hpp
+1
-0
include/ck_tile/ops/flatmm.hpp
include/ck_tile/ops/flatmm.hpp
+1
-0
include/ck_tile/ops/fmha.hpp
include/ck_tile/ops/fmha.hpp
+1
-0
include/ck_tile/ops/fused_moe.hpp
include/ck_tile/ops/fused_moe.hpp
+2
-1
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
...ude/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
+1
-1
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
+634
-59
include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp
include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp
+52
-0
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+3
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
+59
-29
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
+15
-1
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+101
-8
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
+12
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
...ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
+11
-5
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
+12
-0
No files found.
include/ck_tile/host/reference/reference_moe_sorting.hpp
View file @
f18cfec4
...
...
@@ -14,12 +14,15 @@ namespace ck_tile {
template
<
typename
WeightType
,
typename
IndexType
=
index_t
>
CK_TILE_HOST
void
reference_moe_sorting
(
const
HostTensor
<
IndexType
>&
topk_ids
,
const
HostTensor
<
WeightType
>&
weights
,
const
HostTensor
<
IndexType
>&
local_expert_mask
,
HostTensor
<
IndexType
>&
p_sorted_token_ids
,
HostTensor
<
WeightType
>&
sorted_weight
,
HostTensor
<
IndexType
>&
sorted_expert_ids
,
index_t
&
unit_cnt
,
const
index_t
experts
,
const
index_t
unit_size
)
const
index_t
unit_size
,
bool
local_expert_masking
,
bool
skip_experts_with_zero_token
=
true
)
{
const
index_t
num_token
=
topk_ids
.
mDesc
.
get_lengths
()[
0
];
const
index_t
topk
=
topk_ids
.
mDesc
.
get_lengths
()[
1
];
...
...
@@ -33,8 +36,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
#endif
std
::
vector
<
std
::
vector
<
WeightType
>>
expert_token_weights
(
experts
,
std
::
vector
<
WeightType
>
(
unit_size
,
0
));
// count number of unit-size slices in this expert
std
::
vector
<
IndexType
>
expert_slices
(
experts
,
1
);
// count the tokens used in this expert
std
::
vector
<
IndexType
>
expert_slice_idxs
(
experts
,
0
);
// TODO: above 2 buffer seems duplicated
for
(
index_t
t
=
0
;
t
<
num_token
;
t
++
)
{
...
...
@@ -72,8 +78,23 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
IndexType
*
out_tokens
=
p_sorted_token_ids
.
data
();
WeightType
*
out_weights
=
sorted_weight
.
data
();
IndexType
*
out_expert_id
=
sorted_expert_ids
.
data
();
int
curr_expert_id
=
0
;
for
(
index_t
e
=
0
;
e
<
experts
;
e
++
)
{
if
(
local_expert_masking
)
{
if
(
local_expert_mask
(
e
)
==
0
)
continue
;
}
if
(
skip_experts_with_zero_token
)
{
if
(
expert_slice_idxs
[
e
]
==
0
)
{
curr_expert_id
++
;
continue
;
}
}
memcpy
(
out_tokens
,
expert_tokens
[
e
].
data
(),
sizeof
(
index_t
)
*
expert_slices
[
e
]
*
unit_size
);
out_tokens
+=
expert_slices
[
e
]
*
unit_size
;
memcpy
(
out_weights
,
...
...
@@ -83,10 +104,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
for
(
index_t
s
=
0
;
s
<
expert_slices
[
e
];
s
++
)
{
out_expert_id
[
s
]
=
e
;
out_expert_id
[
s
]
=
curr_expert_id
;
unit_cnt
++
;
}
out_expert_id
+=
expert_slices
[
e
];
curr_expert_id
++
;
}
unit_cnt
*=
unit_size
;
return
;
...
...
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
View file @
f18cfec4
...
...
@@ -10,3 +10,4 @@
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/batched_transpose.hpp
View file @
f18cfec4
...
...
@@ -9,3 +9,4 @@
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/common.hpp
View file @
f18cfec4
...
...
@@ -5,3 +5,4 @@
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/common/utils.hpp
0 → 100644
View file @
f18cfec4
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
namespace
ck_tile
{
// clang-format off
template
<
typename
T
>
struct
typeToStr
;
template
<
>
struct
typeToStr
<
float
>
{
static
constexpr
const
char
*
name
=
"fp32"
;
};
template
<
>
struct
typeToStr
<
fp16_t
>
{
static
constexpr
const
char
*
name
=
"fp16"
;
};
template
<
>
struct
typeToStr
<
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
template
<
>
struct
typeToStr
<
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
typeToStr
<
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
template
<
>
struct
typeToStr
<
int8_t
>
{
static
constexpr
const
char
*
name
=
"int8"
;
};
// clang-format on
template
<
typename
ADataType_
,
typename
BDataType_
>
std
::
string
gemm_prec_str
()
{
std
::
string
base_str
=
std
::
string
(
typeToStr
<
ADataType_
>::
name
);
if
(
!
std
::
is_same_v
<
ADataType_
,
BDataType_
>
)
{
base_str
+=
"_"
+
std
::
string
(
typeToStr
<
BDataType_
>::
name
);
}
return
base_str
;
}
}
// namespace ck_tile
include/ck_tile/ops/elementwise.hpp
View file @
f18cfec4
...
...
@@ -6,3 +6,4 @@
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/epilogue.hpp
View file @
f18cfec4
...
...
@@ -8,3 +8,4 @@
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/flatmm.hpp
View file @
f18cfec4
...
...
@@ -9,3 +9,4 @@
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/fmha.hpp
View file @
f18cfec4
...
...
@@ -44,3 +44,4 @@
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/fused_moe.hpp
View file @
f18cfec4
...
...
@@ -7,6 +7,7 @@
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
...
...
@@ -14,6 +15,6 @@
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
View file @
f18cfec4
...
...
@@ -22,7 +22,7 @@
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts *
(
M_a -
1
)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a -
topk (updated
)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
...
...
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
View file @
f18cfec4
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/fused_moe/
pipeli
ne/moe_sorting_problem.hpp
→
include/ck_tile/ops/fused_moe/
ker
ne
l
/moe_sorting_problem.hpp
View file @
f18cfec4
...
...
@@ -25,4 +25,28 @@ struct MoeSortingProblem
InternalLoadUnroll_
;
// TODO: need better design(like tile size)
static
constexpr
index_t
ExpertTile
=
ExpertTile_
;
// TODO: only used in store out
};
template
<
typename
IndexType_
,
typename
WeightType_
,
index_t
SubTokenTile_
,
// 1,2,4,8, or 0 in the future
bool
SubTokenOneShot_
,
// if we only loop over once or not
bool
LocalExpertMasking_
,
// used in EP case
bool
SkipExpertsWithZeroTokens_
=
true
,
index_t
ExpertTile_
=
0
>
struct
MoeSortingProblemEx
{
// TODO: this kernel only support warp per row
using
WeightType
=
remove_cvref_t
<
WeightType_
>
;
using
IndexType
=
remove_cvref_t
<
IndexType_
>
;
static
constexpr
index_t
WarpSize
=
get_warp_size
();
static
constexpr
index_t
WarpsPerBlock
=
1
;
static
constexpr
index_t
SubTokenTile
=
SubTokenTile_
;
static
constexpr
bool
SubTokenOneShot
=
SubTokenOneShot_
;
static
constexpr
bool
LocalExpertMasking
=
LocalExpertMasking_
;
static
constexpr
bool
SkipExpertsWithZeroTokens
=
SkipExpertsWithZeroTokens_
;
static_assert
(
SubTokenTile
==
1
||
SubTokenTile
==
2
||
SubTokenTile
==
4
||
SubTokenTile
==
8
);
static
constexpr
index_t
ExpertTile
=
ExpertTile_
;
// TODO: only used in store out
};
}
// namespace ck_tile
include/ck_tile/ops/gemm.hpp
View file @
f18cfec4
...
...
@@ -29,6 +29,8 @@
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
...
...
@@ -46,3 +48,4 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
View file @
f18cfec4
...
...
@@ -14,24 +14,54 @@ namespace ck_tile {
template
<
typename
Problem_
,
typename
Policy_
=
BlockGemmARegBRegCRegV1DefaultPolicy
>
struct
BlockGemmARegBRegCRegV1
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
static
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
static
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
static
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
static
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
static
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WG
::
kK
;
private:
template
<
typename
PipelineProblem_
,
typename
GemmPolicy_
>
struct
GemmTraits_
{
using
Problem
=
remove_cvref_t
<
PipelineProblem_
>
;
using
Policy
=
remove_cvref_t
<
GemmPolicy_
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
auto
config
=
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
static
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
static
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
static
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WarpGemm
::
kM
);
static
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
static
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WarpGemm
::
kK
;
static
constexpr
index_t
KPack
=
WarpGemm
::
kKPerThread
;
};
public:
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
Traits
=
GemmTraits_
<
Problem
,
Policy
>
;
using
WarpGemm
=
typename
Traits
::
WarpGemm
;
using
BlockGemmShape
=
typename
Traits
::
BlockGemmShape
;
using
ADataType
=
remove_cvref_t
<
typename
Traits
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Traits
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Traits
::
CDataType
>
;
static
constexpr
index_t
KIterPerWarp
=
Traits
::
KIterPerWarp
;
static
constexpr
index_t
MIterPerWarp
=
Traits
::
MIterPerWarp
;
static
constexpr
index_t
NIterPerWarp
=
Traits
::
NIterPerWarp
;
static
constexpr
index_t
MWarp
=
Traits
::
MWarp
;
static
constexpr
index_t
NWarp
=
Traits
::
NWarp
;
CK_TILE_DEVICE
static
constexpr
auto
MakeABlockDistributionEncode
()
{
...
...
@@ -43,7 +73,7 @@ struct BlockGemmARegBRegCRegV1
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_block_outer_dstr_encoding
,
typename
W
G
::
AWarpDstrEncoding
{});
a_block_outer_dstr_encoding
,
typename
W
arpGemm
::
AWarpDstrEncoding
{});
return
a_block_dstr_encode
;
}
...
...
@@ -58,7 +88,7 @@ struct BlockGemmARegBRegCRegV1
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
b_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
b_block_outer_dstr_encoding
,
typename
W
G
::
BWarpDstrEncoding
{});
b_block_outer_dstr_encoding
,
typename
W
arpGemm
::
BWarpDstrEncoding
{});
return
b_block_dstr_encode
;
}
...
...
@@ -73,7 +103,7 @@ struct BlockGemmARegBRegCRegV1
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
W
G
::
CWarpDstrEncoding
{});
c_block_outer_dstr_encoding
,
typename
W
arpGemm
::
CWarpDstrEncoding
{});
return
c_block_dstr_encode
;
}
...
...
@@ -112,13 +142,13 @@ struct BlockGemmARegBRegCRegV1
.
get_static_tile_distribution_encoding
())
>>
,
"C distribution is wrong!"
);
using
AWarpDstr
=
typename
W
G
::
AWarpDstr
;
using
BWarpDstr
=
typename
W
G
::
BWarpDstr
;
using
CWarpDstr
=
typename
W
G
::
CWarpDstr
;
using
AWarpDstr
=
typename
W
arpGemm
::
AWarpDstr
;
using
BWarpDstr
=
typename
W
arpGemm
::
BWarpDstr
;
using
CWarpDstr
=
typename
W
arpGemm
::
CWarpDstr
;
using
AWarpTensor
=
typename
W
G
::
AWarpTensor
;
using
BWarpTensor
=
typename
W
G
::
BWarpTensor
;
using
CWarpTensor
=
typename
W
G
::
CWarpTensor
;
using
AWarpTensor
=
typename
W
arpGemm
::
AWarpTensor
;
using
BWarpTensor
=
typename
W
arpGemm
::
BWarpTensor
;
using
CWarpTensor
=
typename
W
arpGemm
::
CWarpTensor
;
constexpr
auto
a_warp_y_lengths
=
to_sequence
(
AWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
...
...
@@ -157,7 +187,7 @@ struct BlockGemmARegBRegCRegV1
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
// warp GEMM
W
G
{}(
c_warp_tensor
,
a_warp_tensor
,
b_warp_tensor
);
W
arpGemm
{}(
c_warp_tensor
,
a_warp_tensor
,
b_warp_tensor
);
// write C warp tensor into C block tensor
c_block_tensor
.
set_y_sliced_thread_data
(
...
...
@@ -180,7 +210,7 @@ struct BlockGemmARegBRegCRegV1
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
W
G
::
CWarpDstrEncoding
{});
c_block_outer_dstr_encoding
,
typename
W
arpGemm
::
CWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
return
c_block_tensor
;
...
...
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
View file @
f18cfec4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
namespace
ck_tile
{
...
...
@@ -57,6 +59,18 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using
BLayout
=
typename
Base
::
BLayout
;
using
CLayout
=
typename
Base
::
CLayout
;
[[
nodiscard
]]
CK_TILE_HOST
static
const
std
::
string
GetName
()
{
// clang-format off
using
P_
=
GemmPipeline
;
return
concat
(
'_'
,
"gemm_batched"
,
gemm_prec_str
<
ADataType
,
BDataType
>
,
concat
(
'x'
,
P_
::
kMPerBlock
,
P_
::
kNPerBlock
,
P_
::
kKPerBlock
),
concat
(
'x'
,
P_
::
GetVectorSizeA
(),
P_
::
GetVectorSizeB
(),
P_
::
GetVectorSizeC
()),
concat
(
'x'
,
P_
::
kPadM
,
P_
::
kPadN
,
P_
::
kPadK
));
// clang-format on
}
struct
BatchedGemmKernelArgs
:
GemmKernelArgs
{
index_t
batch_stride_A
;
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
f18cfec4
...
...
@@ -8,6 +8,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
namespace
ck_tile
{
...
...
@@ -75,6 +76,13 @@ struct GemmKernel
static
constexpr
auto
I1
=
number
<
1
>
();
static
constexpr
auto
I2
=
number
<
2
>
();
[[
nodiscard
]]
CK_TILE_HOST
static
const
std
::
string
GetName
()
{
// clang-format off
return
concat
(
'_'
,
"gemm"
,
gemm_prec_str
<
ADataType
,
BDataType
>
,
GemmPipeline
::
GetName
());
// clang-format on
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
,
index_t
KBatch
)
{
return
dim3
(
TilePartitioner
::
GridSize
(
M
,
N
),
1
,
KBatch
);
...
...
@@ -455,7 +463,9 @@ struct GemmKernel
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The start memory pointer of the shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
...
...
@@ -465,7 +475,7 @@ struct GemmKernel
CK_TILE_DEVICE
static
void
RunGemm
(
const
ADataType
*
a_ptr
,
const
BDataType
*
b_ptr
,
CDataType
*
c_ptr
,
void
*
smem_ptr
,
void
*
smem_ptr
_0
,
const
GemmKernelArgs
&
kargs
,
const
SplitKBatchOffset
&
splitk_batch_offset
,
const
index_t
block_idx_m
,
...
...
@@ -483,15 +493,67 @@ struct GemmKernel
// Run GEMM cooperatively by whole workgroup.
const
auto
&
a_block_window
=
gemm_tile_windows
.
at
(
I0
);
const
auto
&
b_block_window
=
gemm_tile_windows
.
at
(
I1
);
const
auto
&
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
const
auto
&
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr_0
);
// Run Epilogue Pipeline
auto
&
c_block_window
=
gemm_tile_windows
.
at
(
I2
);
EpiloguePipeline
{}
.
template
operator
()
<
decltype
(
c_block_window
),
decltype
(
c_block_tile
),
DstInMemOp
>(
c_block_window
,
c_block_tile
,
smem_ptr
);
c_block_window
,
c_block_tile
,
smem_ptr_0
);
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset Utility structure used to calculate k batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
* @tparam DstInMemOp Destination memory operation (default: set).
*/
template
<
memory_operation_enum
DstInMemOp
=
memory_operation_enum
::
set
>
CK_TILE_DEVICE
static
void
RunGemm2LDS
(
const
ADataType
*
a_ptr
,
const
BDataType
*
b_ptr
,
CDataType
*
c_ptr
,
void
*
__restrict__
smem_ptr_0
,
void
*
__restrict__
smem_ptr_1
,
const
GemmKernelArgs
&
kargs
,
const
SplitKBatchOffset
&
splitk_batch_offset
,
const
index_t
block_idx_m
,
const
index_t
block_idx_n
)
{
// Create Gemm tensor views, pad views and tile windows
const
auto
&
gemm_tensor_views_tuple
=
MakeGemmTensorViews
<
DstInMemOp
>
(
a_ptr
,
b_ptr
,
c_ptr
,
kargs
,
splitk_batch_offset
);
const
auto
&
gemm_pad_views
=
MakeGemmPadViews
(
gemm_tensor_views_tuple
);
auto
gemm_tile_windows
=
MakeGemmTileWindows
(
gemm_pad_views
,
block_idx_m
,
block_idx_n
);
const
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
splitk_batch_offset
.
splitted_k
);
// Run GEMM cooperatively by whole workgroup.
const
auto
&
a_block_window
=
gemm_tile_windows
.
at
(
I0
);
const
auto
&
b_block_window
=
gemm_tile_windows
.
at
(
I1
);
const
auto
&
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr_0
,
smem_ptr_1
);
// Run Epilogue Pipeline
auto
&
c_block_window
=
gemm_tile_windows
.
at
(
I2
);
EpiloguePipeline
{}
.
template
operator
()
<
decltype
(
c_block_window
),
decltype
(
c_block_tile
),
DstInMemOp
>(
c_block_window
,
c_block_tile
,
smem_ptr_0
);
}
CK_TILE_DEVICE
void
operator
()(
GemmKernelArgs
kargs
)
const
...
...
@@ -509,11 +571,27 @@ struct GemmKernel
CDataType
*
c_ptr
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
);
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
__shared__
char
smem_ptr_0
[
GetSmemSize
()];
__shared__
char
smem_ptr_1
[
GetSmemSize
()];
if
(
kargs
.
k_batch
==
1
)
{
RunGemm
(
a_ptr
,
b_ptr
,
c_ptr
,
smem_ptr
,
kargs
,
splitk_batch_offset
,
i_m
,
i_n
);
if
constexpr
(
GemmPipeline
::
DoubleSmemBuffer
==
true
)
{
RunGemm2LDS
(
a_ptr
,
b_ptr
,
c_ptr
,
smem_ptr_0
,
smem_ptr_1
,
kargs
,
splitk_batch_offset
,
i_m
,
i_n
);
}
else
{
RunGemm
(
a_ptr
,
b_ptr
,
c_ptr
,
smem_ptr_0
,
kargs
,
splitk_batch_offset
,
i_m
,
i_n
);
}
}
else
{
...
...
@@ -522,8 +600,23 @@ struct GemmKernel
if
constexpr
(
!
(
EpiloguePipeline
::
template
GetVectorSizeC
<
CDataType
>()
%
2
!=
0
&&
is_any_of
<
CDataType
,
fp16_t
,
bf16_t
>::
value
))
{
RunGemm
<
memory_operation_enum
::
atomic_add
>
(
a_ptr
,
b_ptr
,
c_ptr
,
smem_ptr
,
kargs
,
splitk_batch_offset
,
i_m
,
i_n
);
if
constexpr
(
GemmPipeline
::
DoubleSmemBuffer
==
true
)
{
RunGemm2LDS
<
memory_operation_enum
::
atomic_add
>
(
a_ptr
,
b_ptr
,
c_ptr
,
smem_ptr_0
,
smem_ptr_1
,
kargs
,
splitk_batch_offset
,
i_m
,
i_n
);
}
else
{
RunGemm
<
memory_operation_enum
::
atomic_add
>
(
a_ptr
,
b_ptr
,
c_ptr
,
smem_ptr_0
,
kargs
,
splitk_batch_offset
,
i_m
,
i_n
);
}
}
}
}
...
...
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
View file @
f18cfec4
...
...
@@ -64,6 +64,18 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
}
};
[[
nodiscard
]]
CK_TILE_HOST
static
const
std
::
string
GetName
()
{
// clang-format off
using
P_
=
GemmPipeline
;
return
concat
(
'_'
,
"gemm_grouped"
,
gemm_prec_str
<
ADataType
,
BDataType
>
,
concat
(
'x'
,
P_
::
kMPerBlock
,
P_
::
kNPerBlock
,
P_
::
kKPerBlock
),
concat
(
'x'
,
P_
::
GetVectorSizeA
(),
P_
::
GetVectorSizeB
(),
P_
::
GetVectorSizeC
()),
concat
(
'x'
,
P_
::
kPadM
,
P_
::
kPadN
,
P_
::
kPadK
));
// clang-format on
}
__host__
static
auto
GetWorkSpaceSize
(
const
std
::
vector
<
GroupedGemmHostArgs
>&
gemm_descs
)
->
std
::
size_t
{
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
View file @
f18cfec4
...
...
@@ -41,20 +41,26 @@ struct GemmPipelineAgBgCrImplBase
store_tile
(
lds_tile_window
,
block_tile_tmp
);
}
template
<
typename
DstBlockTile
,
typename
SrcTileWindow
>
CK_TILE_DEVICE
void
LocalPrefetch
(
DstBlockTile
&
dst_block_tile
,
const
SrcTileWindow
&
lds_tile_window
)
const
{
load_tile
(
dst_block_tile
,
lds_tile_window
);
}
CK_TILE_DEVICE
auto
GetABLdsTensorViews
(
void
*
p_smem
)
const
{
// A tile in LDS
ADataType
*
p_a_lds
=
static_cast
<
ADataType
*>
(
p_smem
);
ADataType
*
__restrict__
p_a_lds
=
static_cast
<
ADataType
*>
(
p_smem
);
constexpr
auto
a_lds_block_desc
=
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>();
auto
a_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds
,
a_lds_block_desc
);
// TODO: LDS alignment should come from Policy!
constexpr
index_t
a_lds_block_space_size_aligned
=
integer_divide_ceil
(
sizeof
(
ADataType
)
*
a_lds_block_desc
.
get_element_space_size
(),
16
)
*
16
;
constexpr
index_t
a_lds_block_space_size_aligned
=
integer_least_multiple
(
sizeof
(
ADataType
)
*
a_lds_block_desc
.
get_element_space_size
(),
16
);
// B tile in LDS
BDataType
*
p_b_lds
=
static_cast
<
BDataType
*>
(
BDataType
*
__restrict__
p_b_lds
=
static_cast
<
BDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
));
constexpr
auto
b_lds_block_desc
=
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>();
auto
b_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds
,
b_lds_block_desc
);
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
View file @
f18cfec4
...
...
@@ -10,6 +10,7 @@
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace
ck_tile
{
...
...
@@ -75,12 +76,23 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadK
=
Problem
::
kPadK
;
static
constexpr
bool
DoubleSmemBuffer
=
Problem
::
DoubleSmemBuffer
;
static
constexpr
bool
HasHotLoop
=
Problem
::
HasHotLoop
;
static
constexpr
auto
TailNum
=
Problem
::
TailNum
;
static
constexpr
auto
Scheduler
=
Problem
::
Scheduler
;
using
Base
::
PrefetchStages
;
[[
nodiscard
]]
CK_TILE_HOST
static
const
std
::
string
GetName
()
{
// clang-format off
return
concat
(
'_'
,
"pipeline_AgBgCrCompV3"
,
BlockSize
,
concat
(
'x'
,
GetVectorSizeA
(),
GetVectorSizeB
(),
GetVectorSizeC
()),
concat
(
'x'
,
kPadM
,
kPadN
,
kPadK
));
// clang-format on
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
...
...
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment