Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7409674a
"src/include/amd_inline_asm.hpp" did not exist on "5245a0162bbbe5f49be9cc8f2189f53465f691ec"
Commit
7409674a
authored
Feb 06, 2025
by
ThomasNing
Browse files
Solving the Review comments
parent
c2bb46ff
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
88 additions
and
109 deletions
+88
-109
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+0
-3
example/ck_tile/03_gemm/universal_gemm.cpp
example/ck_tile/03_gemm/universal_gemm.cpp
+5
-0
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+6
-5
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp
+0
-50
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
...gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+77
-51
No files found.
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
7409674a
...
...
@@ -23,17 +23,14 @@
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
constexpr
bool
DoubleSmemBuffer
=
false
;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
constexpr
bool
DoubleSmemBuffer
=
false
;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
constexpr
bool
DoubleSmemBuffer
=
true
;
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
...
...
example/ck_tile/03_gemm/universal_gemm.cpp
View file @
7409674a
...
...
@@ -28,6 +28,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
bool
DoubleSmemBuffer
=
false
;
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
// Compute friendly for Intrawave scheduler
...
...
@@ -43,6 +45,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
constexpr
bool
DoubleSmemBuffer
=
false
;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
...
...
@@ -57,6 +60,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
bool
DoubleSmemBuffer
=
true
;
#endif
constexpr
bool
kPadM
=
false
;
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
7409674a
...
...
@@ -457,8 +457,8 @@ struct GemmKernel
* @param c_ptr output C pointer
* @param smem_ptr_0 The start memory pointer of the shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset
When there are more than 1 batch ne
ed
s
to
split th
e k
.
*
splitk_
batch
_
offset
stands for its the K from which batch
.
* @param splitk_batch_offset
param splitk_batch_offset Utility structure us
ed to
calculat
e k
* batch
offset
per workgroup
.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
...
...
@@ -501,14 +501,16 @@ struct GemmKernel
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset
When there are more than 1 batch ne
ed
s
to
split the k.
*
splitk_batch_offset stands for its the K from which batch
.
* @param splitk_batch_offset
splitk_batch_offset Utility structure us
ed to
calculate k batch
*
offset per workgroup
.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
...
...
@@ -528,7 +530,6 @@ struct GemmKernel
// Create Gemm tensor views, pad views and tile windows
const
auto
&
gemm_tensor_views_tuple
=
MakeGemmTensorViews
<
DstInMemOp
>
(
a_ptr
,
b_ptr
,
c_ptr
,
kargs
,
splitk_batch_offset
);
;
const
auto
&
gemm_pad_views
=
MakeGemmPadViews
(
gemm_tensor_views_tuple
);
auto
gemm_tile_windows
=
MakeGemmTileWindows
(
gemm_pad_views
,
block_idx_m
,
block_idx_n
);
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp
View file @
7409674a
...
...
@@ -16,56 +16,6 @@ namespace ck_tile {
// member functions instead.
struct
GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
:
public
UniversalGemmBasePolicy
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeC
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetBlockGemm
<
Problem
>
())
>
;
using
WG
=
typename
BlockGemm
::
WarpGemm
;
constexpr
bool
TransposeC
=
Problem
::
TransposeC
;
using
CLayout
=
typename
Problem
::
CLayout
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
// N is contiguous dimension
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
if
constexpr
(
TransposeC
)
{
constexpr
index_t
NDimY
=
CWarpDstr
::
NDimY
;
constexpr
auto
c_warp_y_lengths
=
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
();
static_assert
(
WG
::
WarpGemmAttribute
::
Impl
::
kCM1PerLane
==
c_warp_y_lengths
.
get
(
number
<
NDimY
-
1
>
{}));
return
c_warp_y_lengths
.
get
(
number
<
NDimY
-
1
>
{});
}
else
{
return
WG
::
WarpGemmAttribute
::
Impl
::
kCNLane
/
WG
::
kN
;
}
}
// M is contiguous dimension
else
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
if
constexpr
(
TransposeC
)
{
return
WG
::
WarpGemmAttribute
::
Impl
::
kCNLane
/
WG
::
kN
;
}
else
{
constexpr
index_t
NDimY
=
CWarpDstr
::
NDimY
;
constexpr
auto
c_warp_y_lengths
=
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
();
static_assert
(
WG
::
WarpGemmAttribute
::
Impl
::
kCM1PerLane
==
c_warp_y_lengths
.
get
(
number
<
NDimY
-
1
>
{}));
return
c_warp_y_lengths
.
get
(
number
<
NDimY
-
1
>
{});
}
}
else
{
static_assert
(
false
,
"Unsupported CLayout!"
);
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackA
()
{
...
...
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
View file @
7409674a
...
...
@@ -18,6 +18,15 @@ struct UniversalGemmBasePolicy
static
constexpr
auto
ATileAccessPattern
=
tile_distribution_pattern
::
thread_raked
;
static
constexpr
auto
BTileAccessPattern
=
tile_distribution_pattern
::
thread_raked
;
/**
* @brief Get the maximum global memory vector load size.
*
* @tparam Problem The UniversalGemmPipelineProblem object.
* @tparam DataType The tensor data type we're considering.
* @tparam MNPerBlock The MPerBlock or NPerBlock value depending on tensor (A/B).
* @tparam XPerTile The contiguous Tile dimension size.
* @return Maximum DRAM vector load size.
*/
template
<
typename
Problem
,
typename
DataType
,
index_t
MNPerBlock
,
index_t
XPerTile
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetGlobalVectorLoadSize
()
{
...
...
@@ -88,6 +97,74 @@ struct UniversalGemmBasePolicy
}
}
/**
* @brief Get the vector store size for C tensor.
*
* @tparam Problem - Gemm pipeline problem class.
*
* @note The vector store size for output C tensor would depend on multiple factors
* like its data layout and warp gemm C transposition. In general it would
* be the number of consecutive elements in contiguous C dimension hold by
* single thread.
*
* @return The vector store size for C tensor.
*/
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeC
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
Derived
::
template
GetBlockGemm
<
Problem
>())
>
;
using
WG
=
typename
BlockGemm
::
WarpGemm
;
constexpr
bool
TransposeC
=
Problem
::
TransposeC
;
using
CLayout
=
typename
Problem
::
CLayout
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
// N is contiguous dimension
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
if
constexpr
(
TransposeC
)
{
// In this case each thread has multiple consecutive elements in
// N dimension, however consecutive threads' elements have stride.
constexpr
index_t
NDimY
=
CWarpDstr
::
NDimY
;
constexpr
auto
c_warp_y_lengths
=
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
();
static_assert
(
WG
::
WarpGemmAttribute
::
Impl
::
kCM1PerLane
==
c_warp_y_lengths
.
get
(
number
<
NDimY
-
1
>
{}));
return
c_warp_y_lengths
.
get
(
number
<
NDimY
-
1
>
{});
}
else
{
// In this case each thread has just a single item in Ndim
return
WG
::
WarpGemmAttribute
::
Impl
::
kCNLane
/
WG
::
kN
;
}
}
// M is contiguous dimension
else
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
if
constexpr
(
TransposeC
)
{
// In this case each thread has just a single item in Mdim
return
WG
::
WarpGemmAttribute
::
Impl
::
kCNLane
/
WG
::
kN
;
}
else
{
// In this case each thread has multiple consecutive elements in
// M dimension, however consecutive threads' elements have stride.
constexpr
index_t
NDimY
=
CWarpDstr
::
NDimY
;
constexpr
auto
c_warp_y_lengths
=
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
();
static_assert
(
WG
::
WarpGemmAttribute
::
Impl
::
kCM1PerLane
==
c_warp_y_lengths
.
get
(
number
<
NDimY
-
1
>
{}));
return
c_warp_y_lengths
.
get
(
number
<
NDimY
-
1
>
{});
}
}
else
{
static_assert
(
false
,
"Unsupported CLayout!"
);
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
...
...
@@ -198,57 +275,6 @@ struct UniversalGemmBasePolicy
// UniversalGemm Policy
struct
UniversalGemmPipelineAgBgCrPolicy
:
public
UniversalGemmBasePolicy
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeC
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetBlockGemm
<
Problem
>
())
>
;
using
WG
=
typename
BlockGemm
::
WarpGemm
;
constexpr
bool
TransposeC
=
Problem
::
TransposeC
;
using
CLayout
=
typename
Problem
::
CLayout
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
// N is contiguous dimension
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
if
constexpr
(
TransposeC
)
{
constexpr
index_t
NDimY
=
CWarpDstr
::
NDimY
;
constexpr
auto
c_warp_y_lengths
=
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
();
static_assert
(
WG
::
WarpGemmAttribute
::
Impl
::
kCM1PerLane
==
c_warp_y_lengths
.
get
(
number
<
NDimY
-
1
>
{}));
return
c_warp_y_lengths
.
get
(
number
<
NDimY
-
1
>
{});
}
else
{
return
WG
::
WarpGemmAttribute
::
Impl
::
kCNLane
/
WG
::
kN
;
}
}
// M is contiguous dimension
else
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
if
constexpr
(
TransposeC
)
{
return
WG
::
WarpGemmAttribute
::
Impl
::
kCNLane
/
WG
::
kN
;
}
else
{
constexpr
index_t
NDimY
=
CWarpDstr
::
NDimY
;
constexpr
auto
c_warp_y_lengths
=
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
();
static_assert
(
WG
::
WarpGemmAttribute
::
Impl
::
kCM1PerLane
==
c_warp_y_lengths
.
get
(
number
<
NDimY
-
1
>
{}));
return
c_warp_y_lengths
.
get
(
number
<
NDimY
-
1
>
{});
}
}
else
{
static_assert
(
false
,
"Unsupported CLayout!"
);
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackA
()
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment