Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ef2b53a9
Commit
ef2b53a9
authored
Feb 12, 2025
by
ThomasNing
Browse files
Merge branch 'develop' of
https://github.com/ROCm/composable_kernel
into develop
parents
a9df4183
3c7fef7f
Changes
67
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
802 additions
and
68 deletions
+802
-68
include/ck_tile/ops/flatmm.hpp
include/ck_tile/ops/flatmm.hpp
+1
-0
include/ck_tile/ops/fmha.hpp
include/ck_tile/ops/fmha.hpp
+1
-0
include/ck_tile/ops/fused_moe.hpp
include/ck_tile/ops/fused_moe.hpp
+2
-1
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
...ude/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
+1
-1
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
+634
-59
include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp
include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp
+52
-0
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+1
-0
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
+15
-1
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+8
-0
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
+12
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
+10
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
.../ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
+11
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp
...le/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp
+2
-1
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+16
-2
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
+8
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
+13
-2
include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
+12
-1
include/ck_tile/ops/image_to_column.hpp
include/ck_tile/ops/image_to_column.hpp
+1
-0
include/ck_tile/ops/layernorm2d.hpp
include/ck_tile/ops/layernorm2d.hpp
+1
-0
include/ck_tile/ops/norm_reduce.hpp
include/ck_tile/ops/norm_reduce.hpp
+1
-0
No files found.
include/ck_tile/ops/flatmm.hpp
View file @
ef2b53a9
...
@@ -9,3 +9,4 @@
...
@@ -9,3 +9,4 @@
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/fmha.hpp
View file @
ef2b53a9
...
@@ -44,3 +44,4 @@
...
@@ -44,3 +44,4 @@
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/fused_moe.hpp
View file @
ef2b53a9
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
...
@@ -14,6 +15,6 @@
...
@@ -14,6 +15,6 @@
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
View file @
ef2b53a9
...
@@ -22,7 +22,7 @@
...
@@ -22,7 +22,7 @@
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
//
// max_num_tokens_padded : topk * input_tokens + num_experts *
(
M_a -
1
)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a -
topk (updated
)
// * this could be larger than actual, since actual tokens are on GPU
// * this could be larger than actual, since actual tokens are on GPU
//
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
...
...
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
View file @
ef2b53a9
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/fused_moe/
pipeli
ne/moe_sorting_problem.hpp
→
include/ck_tile/ops/fused_moe/
ker
ne
l
/moe_sorting_problem.hpp
View file @
ef2b53a9
...
@@ -25,4 +25,28 @@ struct MoeSortingProblem
...
@@ -25,4 +25,28 @@ struct MoeSortingProblem
InternalLoadUnroll_
;
// TODO: need better design(like tile size)
InternalLoadUnroll_
;
// TODO: need better design(like tile size)
static
constexpr
index_t
ExpertTile
=
ExpertTile_
;
// TODO: only used in store out
static
constexpr
index_t
ExpertTile
=
ExpertTile_
;
// TODO: only used in store out
};
};
// Compile-time configuration bundle for the "Ex" variant of the MoE sorting
// problem. Each template argument is re-exported as a static member (or type
// alias) so it can be queried through the problem type.
//
// Template parameters:
//   IndexType_                 - index type; exposed after remove_cvref_t
//   WeightType_                - weight type; exposed after remove_cvref_t
//   SubTokenTile_              - 1, 2, 4, 8 (or 0 in the future)
//   SubTokenOneShot_           - if we only loop over once or not
//   LocalExpertMasking_        - used in EP case
//   SkipExpertsWithZeroTokens_ - defaults to true
//   ExpertTile_                - defaults to 0
template <typename IndexType_,
          typename WeightType_,
          index_t SubTokenTile_,
          bool SubTokenOneShot_,
          bool LocalExpertMasking_,
          bool SkipExpertsWithZeroTokens_ = true,
          index_t ExpertTile_             = 0>
struct MoeSortingProblemEx
{
    // TODO: this kernel only support warp per row
    using WeightType = remove_cvref_t<WeightType_>;
    using IndexType  = remove_cvref_t<IndexType_>;

    // Warp geometry: one warp per block.
    static constexpr index_t WarpSize      = get_warp_size();
    static constexpr index_t WarpsPerBlock = 1;

    // Sub-token tiling; only the tile sizes listed in the assert are accepted.
    static constexpr index_t SubTokenTile = SubTokenTile_;
    static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 ||
                  SubTokenTile == 8);
    static constexpr bool SubTokenOneShot = SubTokenOneShot_;

    // Expert handling switches.
    static constexpr bool LocalExpertMasking        = LocalExpertMasking_;
    static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;

    static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/gemm.hpp
View file @
ef2b53a9
...
@@ -46,3 +46,4 @@
...
@@ -46,3 +46,4 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
View file @
ef2b53a9
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -57,6 +59,18 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
...
@@ -57,6 +59,18 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using
BLayout
=
typename
Base
::
BLayout
;
using
BLayout
=
typename
Base
::
BLayout
;
using
CLayout
=
typename
Base
::
CLayout
;
using
CLayout
=
typename
Base
::
CLayout
;
// Returns the identifying name string of the batched GEMM kernel.
//
// The value is whatever concat() produces when given '_' as its first
// argument (presumably the separator -- see ck_tile/host/concat.hpp),
// followed by:
//   - the literal "gemm_batched"
//   - gemm_prec_str<ADataType, BDataType>
//   - an 'x' group of the pipeline's kMPerBlock / kNPerBlock / kKPerBlock
//   - an 'x' group of the pipeline's GetVectorSizeA/B/C()
//   - an 'x' group of the pipeline's kPadM / kPadN / kPadK
//
// Returned as a plain (non-const) std::string: a const-qualified by-value
// return prevents callers from moving out of the temporary.
[[nodiscard]] CK_TILE_HOST static std::string GetName()
{
    // clang-format off
    using P_ = GemmPipeline;
    return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>,
                  concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock),
                  concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
                  concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
    // clang-format on
}
struct
BatchedGemmKernelArgs
:
GemmKernelArgs
struct
BatchedGemmKernelArgs
:
GemmKernelArgs
{
{
index_t
batch_stride_A
;
index_t
batch_stride_A
;
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
ef2b53a9
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -75,6 +76,13 @@ struct GemmKernel
...
@@ -75,6 +76,13 @@ struct GemmKernel
static
constexpr
auto
I1
=
number
<
1
>
();
static
constexpr
auto
I1
=
number
<
1
>
();
static
constexpr
auto
I2
=
number
<
2
>
();
static
constexpr
auto
I2
=
number
<
2
>
();
// Returns the identifying name string of the GEMM kernel.
//
// The value is whatever concat() produces when given '_' as its first
// argument (presumably the separator -- see ck_tile/host/concat.hpp),
// followed by the literal "gemm", gemm_prec_str<ADataType, BDataType> and
// the name reported by GemmPipeline::GetName().
//
// Returned as a plain (non-const) std::string: a const-qualified by-value
// return prevents callers from moving out of the temporary.
[[nodiscard]] CK_TILE_HOST static std::string GetName()
{
    // clang-format off
    return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
    // clang-format on
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
,
index_t
KBatch
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
,
index_t
KBatch
)
{
{
return
dim3
(
TilePartitioner
::
GridSize
(
M
,
N
),
1
,
KBatch
);
return
dim3
(
TilePartitioner
::
GridSize
(
M
,
N
),
1
,
KBatch
);
...
...
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
View file @
ef2b53a9
...
@@ -64,6 +64,18 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
...
@@ -64,6 +64,18 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
}
}
};
};
// Returns the identifying name string of the grouped GEMM kernel.
//
// The value is whatever concat() produces when given '_' as its first
// argument (presumably the separator -- see ck_tile/host/concat.hpp),
// followed by:
//   - the literal "gemm_grouped"
//   - gemm_prec_str<ADataType, BDataType>
//   - an 'x' group of the pipeline's kMPerBlock / kNPerBlock / kKPerBlock
//   - an 'x' group of the pipeline's GetVectorSizeA/B/C()
//   - an 'x' group of the pipeline's kPadM / kPadN / kPadK
//
// Returned as a plain (non-const) std::string: a const-qualified by-value
// return prevents callers from moving out of the temporary.
[[nodiscard]] CK_TILE_HOST static std::string GetName()
{
    // clang-format off
    using P_ = GemmPipeline;
    return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>,
                  concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock),
                  concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
                  concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
    // clang-format on
}
__host__
static
auto
GetWorkSpaceSize
(
const
std
::
vector
<
GroupedGemmHostArgs
>&
gemm_descs
)
__host__
static
auto
GetWorkSpaceSize
(
const
std
::
vector
<
GroupedGemmHostArgs
>&
gemm_descs
)
->
std
::
size_t
->
std
::
size_t
{
{
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
View file @
ef2b53a9
...
@@ -10,6 +10,7 @@
...
@@ -10,6 +10,7 @@
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -81,6 +82,15 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -81,6 +82,15 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
using
Base
::
PrefetchStages
;
using
Base
::
PrefetchStages
;
// Returns the identifying name string of this pipeline.
//
// The value is whatever concat() produces when given '_' as its first
// argument (presumably the separator -- see ck_tile/host/concat.hpp),
// followed by:
//   - the literal "pipeline_AgBgCrCompV3"
//   - BlockSize
//   - an 'x' group of GetVectorSizeA/B/C()
//   - an 'x' group of kPadM / kPadN / kPadK
//
// Returned as a plain (non-const) std::string: a const-qualified by-value
// return prevents callers from moving out of the temporary.
[[nodiscard]] CK_TILE_HOST static std::string GetName()
{
    // clang-format off
    return concat('_', "pipeline_AgBgCrCompV3", BlockSize,
                  concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
                  concat('x', kPadM, kPadN, kPadK));
    // clang-format on
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
View file @
ef2b53a9
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -128,6 +129,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
...
@@ -128,6 +129,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static
constexpr
auto
TailNum
=
Problem
::
TailNum
;
static
constexpr
auto
TailNum
=
Problem
::
TailNum
;
static
constexpr
auto
Scheduler
=
Problem
::
Scheduler
;
static
constexpr
auto
Scheduler
=
Problem
::
Scheduler
;
// Returns the identifying name string of this pipeline.
//
// The value is whatever concat() produces when given '_' as its first
// argument (presumably the separator -- see ck_tile/host/concat.hpp),
// followed by:
//   - the literal "pipeline_AgBgCrMe"
//   - an 'x' group of MPerBlock / NPerBlock / KPerBlock
//   - an 'x' group of GetVectorSizeA/B/C()
//   - an 'x' group of kPadM / kPadN / kPadK
//
// Returned as a plain (non-const) std::string: a const-qualified by-value
// return prevents callers from moving out of the temporary.
[[nodiscard]] CK_TILE_HOST static std::string GetName()
{
    // clang-format off
    // NOTE(review): "pipeline_AgBgCrMe" looks like a truncated "pipeline_AgBgCrMem"
    // (the struct is GemmPipelineAgBgCrMem). Left unchanged here because the name
    // may be matched elsewhere -- confirm before fixing.
    return concat('_', "pipeline_AgBgCrMe",
                  concat('x', MPerBlock, NPerBlock, KPerBlock),
                  concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
                  concat('x', kPadM, kPadN, kPadK));
    // clang-format on
}
using
Base
::
PrefetchStages
;
using
Base
::
PrefetchStages
;
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp
View file @
ef2b53a9
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include <ostream>
#include <ostream>
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
ef2b53a9
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/host/concat.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -39,6 +40,18 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -39,6 +40,18 @@ struct GemmPipelineAGmemBGmemCRegV1
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadK
=
Problem
::
kPadK
;
static
constexpr
bool
kPadK
=
Problem
::
kPadK
;
// Byte granularity used in this struct when rounding the A tile's LDS
// footprint up (integer_divide_ceil(..., kLdsAlignmentInBytes) * kLdsAlignmentInBytes).
static
constexpr
index_t
kLdsAlignmentInBytes
=
16
;
// Returns the identifying name string of this pipeline.
//
// The value is whatever concat() produces when given '_' as its first
// argument (presumably the separator -- see ck_tile/host/concat.hpp),
// followed by:
//   - the literal "pipeline_AGmemBGmemCRegV1"
//   - an 'x' group of kMPerBlock / kNPerBlock / kKPerBlock / BlockSize
//   - an 'x' group of GetVectorSizeA/B/C()
//   - an 'x' group of kPadM / kPadN / kPadK
//
// Returned as a plain (non-const) std::string: a const-qualified by-value
// return prevents callers from moving out of the temporary.
[[nodiscard]] CK_TILE_HOST static std::string GetName()
{
    // clang-format off
    return concat('_', "pipeline_AGmemBGmemCRegV1",
                  concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
                  concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
                  concat('x', kPadM, kPadN, kPadK));
    // clang-format on
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
TransposeC
()
{
return
Problem
::
TransposeC
;
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
TransposeC
()
{
return
Problem
::
TransposeC
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
...
@@ -75,8 +88,9 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -75,8 +88,9 @@ struct GemmPipelineAGmemBGmemCRegV1
auto
a_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds
,
a_lds_block_desc
);
auto
a_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds
,
a_lds_block_desc
);
constexpr
index_t
a_lds_block_space_size_aligned
=
constexpr
index_t
a_lds_block_space_size_aligned
=
integer_divide_ceil
(
sizeof
(
ADataType
)
*
a_lds_block_desc
.
get_element_space_size
(),
16
)
*
integer_divide_ceil
(
sizeof
(
ADataType
)
*
a_lds_block_desc
.
get_element_space_size
(),
16
;
kLdsAlignmentInBytes
)
*
kLdsAlignmentInBytes
;
// B tile in LDS
// B tile in LDS
BDataType
*
p_b_lds
=
static_cast
<
BDataType
*>
(
BDataType
*
p_b_lds
=
static_cast
<
BDataType
*>
(
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
View file @
ef2b53a9
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/host/concat.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -25,6 +26,13 @@ struct GemmPipelineAGmemBGmemCRegV2
...
@@ -25,6 +26,13 @@ struct GemmPipelineAGmemBGmemCRegV2
static
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
kKPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
kKPerBlock
=
BlockGemmShape
::
kK
;
// Returns the identifying name string of this pipeline.
//
// The value is whatever concat() produces when given '_' as its first
// argument (presumably the separator -- see ck_tile/host/concat.hpp),
// followed by the literal "pipeline_AGmemBGmemCRegV2" and an 'x' group of
// kMPerBlock / kNPerBlock / kKPerBlock / kBlockSize.
//
// Returned as a plain (non-const) std::string: a const-qualified by-value
// return prevents callers from moving out of the temporary.
[[nodiscard]] CK_TILE_HOST static std::string GetName()
{
    // clang-format off
    return concat('_', "pipeline_AGmemBGmemCRegV2",
                  concat('x', kMPerBlock, kNPerBlock, kKPerBlock, kBlockSize));
    // clang-format on
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
TransposeC
()
{
return
Problem
::
TransposeC
;
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
TransposeC
()
{
return
Problem
::
TransposeC
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetStaticLdsSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetStaticLdsSize
()
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
View file @
ef2b53a9
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/host/concat.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -35,9 +36,19 @@ struct GemmPipelineProblemBase
...
@@ -35,9 +36,19 @@ struct GemmPipelineProblemBase
static
constexpr
bool
kPadN
=
Traits
::
kPadN
;
static
constexpr
bool
kPadN
=
Traits
::
kPadN
;
static
constexpr
bool
kPadK
=
Traits
::
kPadK
;
static
constexpr
bool
kPadK
=
Traits
::
kPadK
;
static
constexpr
auto
Scheduler
=
GemmPipelineScheduler
::
Default
;
static
constexpr
auto
Scheduler
=
GemmPipelineScheduler
::
Default
;
static
constexpr
index_t
VectorLoadSize
=
Traits
::
_VectorSize
;
static
constexpr
index_t
VectorLoadSize
=
Traits
::
_VectorSize
;
// Returns the identifying name string of this GEMM pipeline problem.
//
// The value is whatever concat() produces when given '_' as its first
// argument (presumably the separator -- see ck_tile/host/concat.hpp),
// followed by:
//   - the literal "gemm_problem"
//   - an 'x' group of VectorLoadSize / kBlockSize
//   - an 'x' group of kPadM / kPadN / kPadK
//   - Scheduler (a GemmPipelineScheduler value; how it is rendered as text
//     is decided by concat() and is not visible here)
//
// Returned as a plain (non-const) std::string: a const-qualified by-value
// return prevents callers from moving out of the temporary.
[[nodiscard]] CK_TILE_HOST static std::string GetName()
{
    // clang-format off
    return concat('_', "gemm_problem",
                  concat('x', VectorLoadSize, kBlockSize),
                  concat('x', kPadM, kPadN, kPadK),
                  Scheduler);
    // clang-format on
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentA
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentA
()
{
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
...
...
include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
View file @
ef2b53a9
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/concat.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -19,6 +20,16 @@ struct TileGemmShape
...
@@ -19,6 +20,16 @@ struct TileGemmShape
static
constexpr
index_t
kM
=
BlockTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kM
=
BlockTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kN
=
BlockTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kN
=
BlockTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kK
=
BlockTile
::
at
(
number
<
2
>
{});
static
constexpr
index_t
kK
=
BlockTile
::
at
(
number
<
2
>
{});
// Returns the identifying name string of this tile GEMM shape.
//
// The value is whatever concat() produces when given '_' as its first
// argument (presumably the separator -- see ck_tile/host/concat.hpp),
// followed by:
//   - the literal "tile_gemm_shape"
//   - an 'x' group of kM / kN / kK / NumWarps
//   - an 'x' group of BlockWarps elements 0, 1, 2
//   - an 'x' group of WarpTile elements 0, 1, 2
//
// Marked [[nodiscard]] to match the other GetName() helpers in ck_tile:
// the call has no purpose other than its returned string.
[[nodiscard]] CK_TILE_HOST static std::string GetName()
{
    // clang-format off
    return concat('_', "tile_gemm_shape",
                  concat('x', kM, kN, kK, NumWarps),
                  concat('x', BlockWarps::at(number<0>{}), BlockWarps::at(number<1>{}), BlockWarps::at(number<2>{})),
                  concat('x', WarpTile::at(number<0>{}), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})));
    // clang-format on
}
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/image_to_column.hpp
View file @
ef2b53a9
...
@@ -8,3 +8,4 @@
...
@@ -8,3 +8,4 @@
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/layernorm2d.hpp
View file @
ef2b53a9
...
@@ -11,3 +11,4 @@
...
@@ -11,3 +11,4 @@
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/norm_reduce.hpp
View file @
ef2b53a9
...
@@ -8,3 +8,4 @@
...
@@ -8,3 +8,4 @@
#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment