Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
8bba35f2
Commit
8bba35f2
authored
Dec 17, 2024
by
Aleksander Dudek
Browse files
[CK_TILE] Refactor GemmKernel - review changes
parent
b9806269
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
108 additions
and
122 deletions
+108
-122
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+0
-1
example/ck_tile/16_batched_gemm/batched_gemm.cpp
example/ck_tile/16_batched_gemm/batched_gemm.cpp
+1
-1
example/ck_tile/16_batched_gemm/batched_gemm.hpp
example/ck_tile/16_batched_gemm/batched_gemm.hpp
+1
-2
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
+1
-1
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+105
-29
include/ck_tile/ops/gemm/problem/gemm_problem.hpp
include/ck_tile/ops/gemm/problem/gemm_problem.hpp
+0
-88
No files found.
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
8bba35f2
...
...
@@ -8,7 +8,6 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/problem/gemm_problem.hpp"
template
<
typename
DataType
>
struct
GemmBasicTypeConfig
;
...
...
example/ck_tile/16_batched_gemm/batched_gemm.cpp
View file @
8bba35f2
...
...
@@ -16,7 +16,7 @@
#include "batched_gemm.hpp"
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
batched_gemm
(
const
BatchedGemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
batched_gemm
(
const
ck_tile
::
BatchedGemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadM
=
false
;
...
...
example/ck_tile/16_batched_gemm/batched_gemm.hpp
View file @
8bba35f2
...
...
@@ -8,7 +8,6 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/problem/gemm_problem.hpp"
template
<
typename
DataType
>
struct
BatchedGemmTypeConfig
;
...
...
@@ -57,4 +56,4 @@ auto create_args(int argc, char* argv[])
}
// host API
float
batched_gemm
(
BatchedGemmHostArgs
args
,
const
ck_tile
::
stream_config
&
s
);
float
batched_gemm
(
ck_tile
::
BatchedGemmHostArgs
args
,
const
ck_tile
::
stream_config
&
s
);
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
View file @
8bba35f2
...
...
@@ -20,7 +20,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int
n_warmup
,
int
n_repeat
)
{
BatchedGemmHostArgs
args
;
ck_tile
::
BatchedGemmHostArgs
args
;
args
.
a_ptr
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
b_ptr
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
c_ptr
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
8bba35f2
...
...
@@ -12,6 +12,82 @@
namespace
ck_tile
{
/**
 * @brief Host-side description of a single GEMM problem: its extents and operand strides.
 *
 * Stores the three extents (M, N, K) and one stride per operand (A, B, C).
 * NOTE(review): the unit/meaning of each stride (presumably the leading dimension in
 * elements) is not established by this struct — confirm against the kernel's use.
 *
 * Every member carries a value-initialising default so that a default-constructed
 * object has determinate (zero) contents instead of indeterminate values. Callers are
 * still expected to populate every field before launching a kernel.
 */
struct GemmProblem
{
    /// Creates a problem with every extent and stride set to zero.
    CK_TILE_HOST GemmProblem() = default;

    /**
     * @brief Creates a fully specified problem.
     *
     * @param M_        value stored in M
     * @param N_        value stored in N
     * @param K_        value stored in K
     * @param stride_A_ value stored in stride_A
     * @param stride_B_ value stored in stride_B
     * @param stride_C_ value stored in stride_C
     */
    CK_TILE_HOST GemmProblem(index_t M_,
                             index_t N_,
                             index_t K_,
                             index_t stride_A_,
                             index_t stride_B_,
                             index_t stride_C_)
        : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
    {
    }

    index_t M{};
    index_t N{};
    index_t K{};
    index_t stride_A{};
    index_t stride_B{};
    index_t stride_C{};
};
/**
 * @brief Host-side argument pack for a GEMM launch: problem description plus operand pointers.
 *
 * Extends GemmProblem with type-erased pointers to the A, B and C operands and a
 * k_batch value.
 * NOTE(review): k_batch is presumably the split-K factor — confirm against the kernel.
 *
 * Members carry default initialisers (null pointers, zero k_batch) so that a
 * default-constructed object never exposes indeterminate values. Callers are still
 * expected to assign every field before use.
 */
struct GemmHostArgs : public GemmProblem
{
    /// Creates an argument pack with null operand pointers and zeroed scalars.
    CK_TILE_HOST GemmHostArgs() = default;

    /**
     * @brief Creates a fully specified argument pack.
     *
     * @param a_ptr_    read-only pointer stored in a_ptr
     * @param b_ptr_    read-only pointer stored in b_ptr
     * @param c_ptr_    writable pointer stored in c_ptr
     * @param k_batch_  value stored in k_batch
     * @param M_        forwarded to GemmProblem
     * @param N_        forwarded to GemmProblem
     * @param K_        forwarded to GemmProblem
     * @param stride_A_ forwarded to GemmProblem
     * @param stride_B_ forwarded to GemmProblem
     * @param stride_C_ forwarded to GemmProblem
     */
    CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
                              const void* b_ptr_,
                              void* c_ptr_,
                              index_t k_batch_,
                              index_t M_,
                              index_t N_,
                              index_t K_,
                              index_t stride_A_,
                              index_t stride_B_,
                              index_t stride_C_)
        : GemmProblem(M_, N_, K_, stride_A_, stride_B_, stride_C_),
          a_ptr(a_ptr_),
          b_ptr(b_ptr_),
          c_ptr(c_ptr_),
          k_batch(k_batch_)
    {
    }

    const void* a_ptr = nullptr;
    const void* b_ptr = nullptr;
    void* c_ptr       = nullptr;
    index_t k_batch{};
};
/**
 * @brief Host-side argument pack for a batched GEMM launch.
 *
 * Extends GemmHostArgs with one batch stride per operand (A, B, C) and a batch count.
 * NOTE(review): batch strides are presumably the element distance between consecutive
 * batches of each operand — confirm against the batched kernel.
 *
 * The batch members carry value-initialising defaults so that the common pattern
 * "default-construct, then assign fields" never reads an indeterminate value for a
 * field that was left unassigned.
 */
struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs
{
    /// Creates an argument pack with zeroed batch strides and batch count.
    CK_TILE_HOST BatchedGemmHostArgs() = default;

    /**
     * @brief Creates a fully specified batched argument pack.
     *
     * The first ten arguments are forwarded unchanged to GemmHostArgs.
     *
     * @param a_ptr_          forwarded to GemmHostArgs
     * @param b_ptr_          forwarded to GemmHostArgs
     * @param c_ptr_          forwarded to GemmHostArgs
     * @param k_batch_        forwarded to GemmHostArgs
     * @param M_              forwarded to GemmHostArgs
     * @param N_              forwarded to GemmHostArgs
     * @param K_              forwarded to GemmHostArgs
     * @param stride_A_       forwarded to GemmHostArgs
     * @param stride_B_       forwarded to GemmHostArgs
     * @param stride_C_       forwarded to GemmHostArgs
     * @param batch_stride_A_ value stored in batch_stride_A
     * @param batch_stride_B_ value stored in batch_stride_B
     * @param batch_stride_C_ value stored in batch_stride_C
     * @param batch_count_    value stored in batch_count
     */
    CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
                                     const void* b_ptr_,
                                     void* c_ptr_,
                                     ck_tile::index_t k_batch_,
                                     ck_tile::index_t M_,
                                     ck_tile::index_t N_,
                                     ck_tile::index_t K_,
                                     ck_tile::index_t stride_A_,
                                     ck_tile::index_t stride_B_,
                                     ck_tile::index_t stride_C_,
                                     ck_tile::index_t batch_stride_A_,
                                     ck_tile::index_t batch_stride_B_,
                                     ck_tile::index_t batch_stride_C_,
                                     ck_tile::index_t batch_count_)
        : GemmHostArgs(
              a_ptr_, b_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_C_),
          batch_stride_A(batch_stride_A_),
          batch_stride_B(batch_stride_B_),
          batch_stride_C(batch_stride_C_),
          batch_count(batch_count_)
    {
    }

    ck_tile::index_t batch_stride_A{};
    ck_tile::index_t batch_stride_B{};
    ck_tile::index_t batch_stride_C{};
    ck_tile::index_t batch_count{};
};
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
>
struct
GemmKernel
{
...
...
@@ -147,7 +223,7 @@ struct GemmKernel
CDataType
*
c_ptr
,
const
GemmKernelArgs
&
kargs
)
const
{
auto
&
&
a_tensor_view
=
[
&
]()
{
const
auto
&
a_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
...
...
@@ -168,7 +244,7 @@ struct GemmKernel
}
}();
auto
&
&
b_tensor_view
=
[
&
]()
{
const
auto
&
b_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
...
...
@@ -189,7 +265,7 @@ struct GemmKernel
}
}();
auto
&
&
c_tensor_view
=
[
&
]()
{
const
auto
&
c_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
...
...
@@ -214,10 +290,10 @@ struct GemmKernel
}
template
<
typename
TensorView
>
CK_TILE_DEVICE
auto
MakeGemmPadViews
(
TensorView
&
&
views
)
const
CK_TILE_DEVICE
auto
MakeGemmPadViews
(
const
TensorView
&
views
)
const
{
auto
&
&
a_pad_view
=
[
&
]()
{
auto
&
&
a_tensor_view
=
views
.
at
(
I0
);
const
auto
&
a_pad_view
=
[
&
]()
{
const
auto
&
a_tensor_view
=
views
.
at
(
I0
);
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
pad_tensor_view
(
...
...
@@ -234,8 +310,8 @@ struct GemmKernel
}
}();
auto
&
&
b_pad_view
=
[
&
]()
{
auto
&
&
b_tensor_view
=
views
.
at
(
I1
);
const
auto
&
b_pad_view
=
[
&
]()
{
const
auto
&
b_tensor_view
=
views
.
at
(
I1
);
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
return
pad_tensor_view
(
...
...
@@ -252,8 +328,8 @@ struct GemmKernel
}
}();
auto
&
&
c_pad_view
=
[
&
]()
{
auto
&
&
c_tensor_view
=
views
.
at
(
I2
);
const
auto
&
c_pad_view
=
[
&
]()
{
const
auto
&
c_tensor_view
=
views
.
at
(
I2
);
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
pad_tensor_view
(
...
...
@@ -275,22 +351,22 @@ struct GemmKernel
template
<
typename
PadView
>
CK_TILE_DEVICE
auto
MakeGemmTileWindows
(
PadView
&
&
views
,
const
index_t
i_m
,
const
index_t
i_n
)
const
MakeGemmTileWindows
(
const
PadView
&
views
,
const
index_t
i_m
,
const
index_t
i_n
)
const
{
auto
&
&
a_pad_view
=
views
.
at
(
I0
);
auto
&
&
a_block_window
=
make_tile_window
(
const
auto
&
a_pad_view
=
views
.
at
(
I0
);
const
auto
&
a_block_window
=
make_tile_window
(
a_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_m
,
0
});
auto
&
&
b_pad_view
=
views
.
at
(
I1
);
auto
&
&
b_block_window
=
make_tile_window
(
const
auto
&
b_pad_view
=
views
.
at
(
I1
);
const
auto
&
b_block_window
=
make_tile_window
(
b_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_n
,
0
});
auto
&
&
c_pad_view
=
views
.
at
(
I2
);
auto
&
&
c_block_window
=
make_tile_window
(
const
auto
&
c_pad_view
=
views
.
at
(
I2
);
const
auto
&
c_block_window
=
make_tile_window
(
c_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
i_m
,
i_n
});
...
...
@@ -299,15 +375,14 @@ struct GemmKernel
}
/**
* Create tensor views, pad views, tile windows.
* Runs GEMM cooperatively by whole workgroup with CShuffle or Default 2D Epilogue
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param c_ptr output C pointer
* @param kargs GEMM kernel arguments
* @param block_idx_m M block index
* @param block_idx_n N block index
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*/
CK_TILE_DEVICE
void
RunGemm
(
const
ADataType
*
a_ptr
,
const
BDataType
*
b_ptr
,
...
...
@@ -317,9 +392,10 @@ struct GemmKernel
const
index_t
block_idx_n
)
const
{
// Create Gemm tensor views, pad views and tile windows
auto
&&
gemm_tensor_views_tuple
=
MakeGemmTensorViews
(
a_ptr
,
b_ptr
,
c_ptr
,
kargs
);
auto
&&
gemm_pad_views
=
MakeGemmPadViews
(
gemm_tensor_views_tuple
);
auto
&&
gemm_tile_windows
=
MakeGemmTileWindows
(
gemm_pad_views
,
block_idx_m
,
block_idx_n
);
const
auto
&
gemm_tensor_views_tuple
=
MakeGemmTensorViews
(
a_ptr
,
b_ptr
,
c_ptr
,
kargs
);
const
auto
&
gemm_pad_views
=
MakeGemmPadViews
(
gemm_tensor_views_tuple
);
const
auto
&
gemm_tile_windows
=
MakeGemmTileWindows
(
gemm_pad_views
,
block_idx_m
,
block_idx_n
);
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
...
...
@@ -327,13 +403,13 @@ struct GemmKernel
const
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
kargs
.
K
);
// Run GEMM cooperatively by whole workgroup.
auto
&
&
a_block_window
=
gemm_tile_windows
.
at
(
I0
);
auto
&
&
b_block_window
=
gemm_tile_windows
.
at
(
I1
);
auto
&&
c_block_tile
=
const
auto
&
a_block_window
=
gemm_tile_windows
.
at
(
I0
);
const
auto
&
b_block_window
=
gemm_tile_windows
.
at
(
I1
);
auto
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
// Run CShuffle or Default 2D Epilogue
auto
&&
c_block_window
=
gemm_tile_windows
.
at
(
I2
);
// Run Epilogue Pipeline
auto
c_block_window
=
gemm_tile_windows
.
at
(
I2
);
EpiloguePipeline
{}(
c_block_window
,
c_block_tile
);
}
...
...
include/ck_tile/ops/gemm/problem/gemm_problem.hpp
deleted
100644 → 0
View file @
b9806269
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
namespace
ck_tile
{
/**
 * @brief Host-side description of a single GEMM problem: its extents and operand strides.
 *
 * Stores the three extents (M, N, K) and one stride per operand (A, B, C).
 * NOTE(review): the unit/meaning of each stride is not established by this struct —
 * confirm against the kernel's use.
 *
 * Every member carries a value-initialising default so that a default-constructed
 * object has determinate (zero) contents instead of indeterminate values.
 */
struct Problem
{
    /// Creates a problem with every extent and stride set to zero.
    CK_TILE_HOST Problem() = default;

    /**
     * @brief Creates a fully specified problem.
     *
     * @param M_        value stored in M
     * @param N_        value stored in N
     * @param K_        value stored in K
     * @param stride_A_ value stored in stride_A
     * @param stride_B_ value stored in stride_B
     * @param stride_C_ value stored in stride_C
     */
    CK_TILE_HOST Problem(index_t M_,
                         index_t N_,
                         index_t K_,
                         index_t stride_A_,
                         index_t stride_B_,
                         index_t stride_C_)
        : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
    {
    }

    index_t M{};
    index_t N{};
    index_t K{};
    index_t stride_A{};
    index_t stride_B{};
    index_t stride_C{};
};
/**
 * @brief Host-side argument pack for a GEMM launch: problem description plus operand pointers.
 *
 * Extends Problem with type-erased pointers to the A, B and C operands and a k_batch value.
 * NOTE(review): k_batch is presumably the split-K factor — confirm against the kernel.
 *
 * Members carry default initialisers (null pointers, zero k_batch) so that a
 * default-constructed object never exposes indeterminate values.
 */
struct GemmHostArgs : public Problem
{
    /// Creates an argument pack with null operand pointers and zeroed scalars.
    CK_TILE_HOST GemmHostArgs() = default;

    /**
     * @brief Creates a fully specified argument pack.
     *
     * @param a_ptr_    read-only pointer stored in a_ptr
     * @param b_ptr_    read-only pointer stored in b_ptr
     * @param c_ptr_    writable pointer stored in c_ptr
     * @param k_batch_  value stored in k_batch
     * @param M_        forwarded to Problem
     * @param N_        forwarded to Problem
     * @param K_        forwarded to Problem
     * @param stride_A_ forwarded to Problem
     * @param stride_B_ forwarded to Problem
     * @param stride_C_ forwarded to Problem
     */
    CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
                              const void* b_ptr_,
                              void* c_ptr_,
                              index_t k_batch_,
                              index_t M_,
                              index_t N_,
                              index_t K_,
                              index_t stride_A_,
                              index_t stride_B_,
                              index_t stride_C_)
        : Problem(M_, N_, K_, stride_A_, stride_B_, stride_C_),
          a_ptr(a_ptr_),
          b_ptr(b_ptr_),
          c_ptr(c_ptr_),
          k_batch(k_batch_)
    {
    }

    const void* a_ptr = nullptr;
    const void* b_ptr = nullptr;
    void* c_ptr       = nullptr;
    index_t k_batch{};
};
/**
 * @brief Host-side argument pack for a batched GEMM launch.
 *
 * Extends GemmHostArgs with one batch stride per operand (A, B, C) and a batch count.
 * NOTE(review): batch strides are presumably the element distance between consecutive
 * batches of each operand — confirm against the batched kernel.
 *
 * The batch members carry value-initialising defaults so that a default-constructed
 * object never exposes indeterminate values for fields that were left unassigned.
 */
struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs
{
    /// Creates an argument pack with zeroed batch strides and batch count.
    CK_TILE_HOST BatchedGemmHostArgs() = default;

    /**
     * @brief Creates a fully specified batched argument pack.
     *
     * The first ten arguments are forwarded unchanged to GemmHostArgs.
     *
     * @param a_ptr_          forwarded to GemmHostArgs
     * @param b_ptr_          forwarded to GemmHostArgs
     * @param c_ptr_          forwarded to GemmHostArgs
     * @param k_batch_        forwarded to GemmHostArgs
     * @param M_              forwarded to GemmHostArgs
     * @param N_              forwarded to GemmHostArgs
     * @param K_              forwarded to GemmHostArgs
     * @param stride_A_       forwarded to GemmHostArgs
     * @param stride_B_       forwarded to GemmHostArgs
     * @param stride_C_       forwarded to GemmHostArgs
     * @param batch_stride_A_ value stored in batch_stride_A
     * @param batch_stride_B_ value stored in batch_stride_B
     * @param batch_stride_C_ value stored in batch_stride_C
     * @param batch_count_    value stored in batch_count
     */
    CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
                                     const void* b_ptr_,
                                     void* c_ptr_,
                                     ck_tile::index_t k_batch_,
                                     ck_tile::index_t M_,
                                     ck_tile::index_t N_,
                                     ck_tile::index_t K_,
                                     ck_tile::index_t stride_A_,
                                     ck_tile::index_t stride_B_,
                                     ck_tile::index_t stride_C_,
                                     ck_tile::index_t batch_stride_A_,
                                     ck_tile::index_t batch_stride_B_,
                                     ck_tile::index_t batch_stride_C_,
                                     ck_tile::index_t batch_count_)
        : GemmHostArgs(
              a_ptr_, b_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_C_),
          batch_stride_A(batch_stride_A_),
          batch_stride_B(batch_stride_B_),
          batch_stride_C(batch_stride_C_),
          batch_count(batch_count_)
    {
    }

    ck_tile::index_t batch_stride_A{};
    ck_tile::index_t batch_stride_B{};
    ck_tile::index_t batch_stride_C{};
    ck_tile::index_t batch_count{};
};
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment