Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b4f65acf
Commit
b4f65acf
authored
Dec 12, 2024
by
Aleksander Dudek
Browse files
[CK TILE] Refactor GemmKernel - naming changes, add problem
parent
f79f727c
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
185 additions
and
110 deletions
+185
-110
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+12
-12
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+2
-15
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+5
-5
example/ck_tile/16_batched_gemm/batched_gemm.cpp
example/ck_tile/16_batched_gemm/batched_gemm.cpp
+16
-4
example/ck_tile/16_batched_gemm/batched_gemm.hpp
example/ck_tile/16_batched_gemm/batched_gemm.hpp
+31
-2
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
+1
-1
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
+28
-35
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+33
-35
include/ck_tile/ops/gemm/problem/gemm_problem.hpp
include/ck_tile/ops/gemm/problem/gemm_problem.hpp
+56
-0
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
+1
-1
No files found.
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
b4f65acf
...
...
@@ -15,7 +15,7 @@
#include "gemm_basic.hpp"
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
gemm_calc
(
const
gemm_basic_a
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
gemm_calc
(
const
ck_tile
::
GemmHostA
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadM
=
false
;
...
...
@@ -79,17 +79,17 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeK
a
rgs
(
args
.
p_a
,
args
.
p_b
,
args
.
p_c
,
args
.
M
,
args
.
N
,
args
.
K
,
args
.
stride_A
,
args
.
stride_B
,
args
.
stride_C
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
kbatch
);
auto
kargs
=
Kernel
::
MakeK
ernelA
rgs
(
args
.
a_ptr
,
args
.
b_ptr
,
args
.
c_ptr
,
args
.
M
,
args
.
N
,
args
.
K
,
args
.
stride_A
,
args
.
stride_B
,
args
.
stride_C
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
k
_
batch
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
!
Kernel
::
IsSupportedArgument
(
kargs
))
...
...
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
b4f65acf
...
...
@@ -8,6 +8,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/problem/gemm_problem.hpp"
template
<
typename
DataType
>
struct
GemmBasicTypeConfig
;
...
...
@@ -51,20 +52,6 @@ using BDataType = Types::BDataType;
using
AccDataType
=
Types
::
AccDataType
;
using
CDataType
=
Types
::
CDataType
;
struct
gemm_basic_args
{
const
void
*
p_a
;
const
void
*
p_b
;
void
*
p_c
;
ck_tile
::
index_t
kbatch
;
ck_tile
::
index_t
M
;
ck_tile
::
index_t
N
;
ck_tile
::
index_t
K
;
ck_tile
::
index_t
stride_A
;
ck_tile
::
index_t
stride_B
;
ck_tile
::
index_t
stride_C
;
};
auto
create_args
(
int
argc
,
char
*
argv
[])
{
ck_tile
::
ArgParser
arg_parser
;
...
...
@@ -89,4 +76,4 @@ auto create_args(int argc, char* argv[])
}
// host API
float
gemm_calc
(
gemm_basic_a
rgs
args
,
const
ck_tile
::
stream_config
&
s
);
float
gemm_calc
(
ck_tile
::
GemmHostA
rgs
args
,
const
ck_tile
::
stream_config
&
s
);
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
b4f65acf
...
...
@@ -16,11 +16,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int
n_warmup
,
int
n_repeat
)
{
gemm_basic_a
rgs
args
;
args
.
p_a
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
p_b
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
p_c
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
args
.
kbatch
=
kbatch
;
ck_tile
::
GemmHostA
rgs
args
;
args
.
a_ptr
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
b_ptr
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
c_ptr
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
args
.
k
_
batch
=
kbatch
;
args
.
M
=
M
;
args
.
N
=
N
;
args
.
K
=
K
;
...
...
example/ck_tile/16_batched_gemm/batched_gemm.cpp
View file @
b4f65acf
...
...
@@ -16,7 +16,7 @@
#include "batched_gemm.hpp"
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
batched_gemm
(
const
b
atched
_g
emm
_ka
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
batched_gemm
(
const
B
atched
G
emm
HostA
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadM
=
false
;
...
...
@@ -79,9 +79,21 @@ float batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config&
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
);
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
.
a_ptr
,
args
.
b_ptr
,
args
.
c_ptr
,
args
.
M
,
args
.
N
,
args
.
K
,
args
.
stride_A
,
args
.
stride_B
,
args
.
stride_C
,
args
.
batch_stride_A
,
args
.
batch_stride_B
,
args
.
batch_stride_C
,
args
.
batch_count
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
batch_count
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
s
.
log_level_
>
0
)
...
...
example/ck_tile/16_batched_gemm/batched_gemm.hpp
View file @
b4f65acf
...
...
@@ -8,6 +8,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/problem/gemm_problem.hpp"
template
<
typename
DataType
>
struct
BatchedGemmTypeConfig
;
...
...
@@ -29,8 +30,36 @@ using BDataType = Types::BDataType;
using
AccDataType
=
Types
::
AccDataType
;
using
CDataType
=
Types
::
CDataType
;
struct
b
atched
_g
emm
_ka
rgs
:
public
ck_tile
::
Batched
GemmH
a
rgs
struct
B
atched
G
emm
HostA
rgs
:
public
ck_tile
::
GemmH
ostA
rgs
{
CK_TILE_HOST
BatchedGemmHostArgs
()
=
default
;
CK_TILE_HOST
BatchedGemmHostArgs
(
const
void
*
a_ptr_
,
const
void
*
b_ptr_
,
void
*
c_ptr_
,
ck_tile
::
index_t
k_batch_
,
ck_tile
::
index_t
M_
,
ck_tile
::
index_t
N_
,
ck_tile
::
index_t
K_
,
ck_tile
::
index_t
stride_A_
,
ck_tile
::
index_t
stride_B_
,
ck_tile
::
index_t
stride_C_
,
ck_tile
::
index_t
batch_stride_A_
,
ck_tile
::
index_t
batch_stride_B_
,
ck_tile
::
index_t
batch_stride_C_
,
ck_tile
::
index_t
batch_count_
)
:
GemmHostArgs
(
a_ptr_
,
b_ptr_
,
c_ptr_
,
k_batch_
,
M_
,
N_
,
K_
,
stride_A_
,
stride_B_
,
stride_C_
),
batch_stride_A
(
batch_stride_A_
),
batch_stride_B
(
batch_stride_B_
),
batch_stride_C
(
batch_stride_C_
),
batch_count
(
batch_count_
)
{
}
ck_tile
::
index_t
batch_stride_A
;
ck_tile
::
index_t
batch_stride_B
;
ck_tile
::
index_t
batch_stride_C
;
ck_tile
::
index_t
batch_count
;
};
auto
create_args
(
int
argc
,
char
*
argv
[])
...
...
@@ -60,4 +89,4 @@ auto create_args(int argc, char* argv[])
}
// host API
float
batched_gemm
(
b
atched
_g
emm
_ka
rgs
args
,
const
ck_tile
::
stream_config
&
s
);
float
batched_gemm
(
B
atched
G
emm
HostA
rgs
args
,
const
ck_tile
::
stream_config
&
s
);
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
View file @
b4f65acf
...
...
@@ -20,7 +20,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int
n_warmup
,
int
n_repeat
)
{
b
atched
_g
emm
_ka
rgs
args
;
B
atched
G
emm
HostA
rgs
args
;
args
.
a_ptr
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
b_ptr
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
c_ptr
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
...
...
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
View file @
b4f65acf
...
...
@@ -7,20 +7,12 @@
namespace
ck_tile
{
struct
BatchedGemmHargs
:
GemmHargs
{
index_t
batch_stride_A
;
index_t
batch_stride_B
;
index_t
batch_stride_C
;
index_t
batch_count
;
};
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
>
struct
BatchedGemmKernel
:
public
GemmKernel
<
TilePartitioner_
,
GemmPipeline_
,
EpiloguePipeline_
>
{
using
Base
=
GemmKernel
<
TilePartitioner_
,
GemmPipeline_
,
EpiloguePipeline_
>
;
using
GemmK
a
rgs
=
typename
Base
::
GemmK
a
rgs
;
using
GemmK
ernelA
rgs
=
typename
Base
::
GemmK
ernelA
rgs
;
using
ADataType
=
typename
Base
::
ADataType
;
using
BDataType
=
typename
Base
::
BDataType
;
...
...
@@ -33,7 +25,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using
BLayout
=
typename
Base
::
BLayout
;
using
CLayout
=
typename
Base
::
CLayout
;
struct
BatchedGemmK
a
rgs
:
GemmK
a
rgs
struct
BatchedGemmK
ernelA
rgs
:
GemmK
ernelA
rgs
{
index_t
batch_stride_A
;
index_t
batch_stride_B
;
...
...
@@ -41,33 +33,34 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
index_t
batch_count
;
};
using
Kargs
=
BatchedGemmKargs
;
using
Hargs
=
BatchedGemmHargs
;
using
KernelArgs
=
BatchedGemmKernelArgs
;
__host__
static
constexpr
auto
GridSize
(
const
Hargs
&
k
)
__host__
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
,
index_t
batch_count
)
{
return
TilePartitioner
::
GridSize
(
k
.
M
,
k
.
N
,
k
.
batch_count
);
return
TilePartitioner
::
GridSize
(
M
,
N
,
batch_count
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
Base
::
KernelBlockSize
);
}
CK_TILE_HOST
static
constexpr
BatchedGemmKargs
MakeKargs
(
const
Hargs
&
h
)
CK_TILE_HOST
static
constexpr
BatchedGemmKernelArgs
MakeKernelArgs
(
const
void
*
a_ptr
,
const
void
*
b_ptr
,
void
*
c_ptr
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
stride_A
,
index_t
stride_B
,
index_t
stride_C
,
index_t
batch_stride_A
,
index_t
batch_stride_B
,
index_t
batch_stride_C
,
index_t
batch_count
)
{
Kargs
k
;
k
.
a_ptr
=
h
.
a_ptr
;
k
.
b_ptr
=
h
.
b_ptr
;
k
.
c_ptr
=
h
.
c_ptr
;
k
.
M
=
h
.
M
;
k
.
N
=
h
.
N
;
k
.
K
=
h
.
K
;
k
.
stride_A
=
h
.
stride_A
;
k
.
stride_B
=
h
.
stride_B
;
k
.
stride_C
=
h
.
stride_C
;
k
.
batch_stride_A
=
h
.
batch_stride_A
;
k
.
batch_stride_B
=
h
.
batch_stride_B
;
k
.
batch_stride_C
=
h
.
batch_stride_C
;
k
.
batch_count
=
h
.
batch_count
;
return
k
;
return
BatchedGemmKernelArgs
{{
a_ptr
,
b_ptr
,
c_ptr
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
},
batch_stride_A
,
batch_stride_B
,
batch_stride_C
,
batch_count
};
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
...
...
@@ -75,7 +68,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
return
max
(
GemmPipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
}
CK_TILE_DEVICE
void
operator
()(
Ka
rgs
kargs
)
const
CK_TILE_DEVICE
void
operator
()(
BatchedGemmKernelA
rgs
kargs
)
const
{
const
auto
[
i_m
,
i_n
]
=
TilePartitioner
{}();
const
auto
i_batch
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
);
...
...
@@ -83,17 +76,17 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
// options
const
auto
batch_stride_A
=
__builtin_amdgcn_readfirstlane
(
kargs
.
batch_stride_A
);
const
auto
batch_offset_A
=
__builtin_amdgcn_readfirstlane
(
i_batch
*
batch_stride_A
);
const
ADataType
*
a_
start
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
)
+
batch_offset_A
;
const
ADataType
*
a_
ptr
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
)
+
batch_offset_A
;
const
auto
batch_stride_B
=
__builtin_amdgcn_readfirstlane
(
kargs
.
batch_stride_B
);
const
auto
batch_offset_B
=
__builtin_amdgcn_readfirstlane
(
i_batch
*
batch_stride_B
);
const
BDataType
*
b_
start
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
)
+
batch_offset_B
;
const
BDataType
*
b_
ptr
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
)
+
batch_offset_B
;
const
auto
batch_stride_C
=
__builtin_amdgcn_readfirstlane
(
kargs
.
batch_stride_C
);
const
auto
batch_offset_C
=
__builtin_amdgcn_readfirstlane
(
i_batch
*
batch_stride_C
);
CDataType
*
c_
start
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
)
+
batch_offset_C
;
CDataType
*
c_
ptr
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
)
+
batch_offset_C
;
this
->
RunGemm
(
a_
start
,
b_start
,
c_start
,
kargs
,
i_m
,
i_n
);
this
->
RunGemm
(
a_
ptr
,
b_ptr
,
c_ptr
,
kargs
,
i_m
,
i_n
);
}
};
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
b4f65acf
...
...
@@ -12,19 +12,6 @@
namespace
ck_tile
{
struct
GemmHargs
{
const
void
*
a_ptr
;
const
void
*
b_ptr
;
void
*
c_ptr
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
stride_A
;
index_t
stride_B
;
index_t
stride_C
;
};
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
>
struct
GemmKernel
{
...
...
@@ -51,7 +38,7 @@ struct GemmKernel
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
KernelBlockSize
);
}
struct
GemmK
a
rgs
struct
GemmK
ernelA
rgs
{
const
void
*
a_ptr
;
const
void
*
b_ptr
;
...
...
@@ -64,17 +51,17 @@ struct GemmKernel
index_t
stride_C
;
};
CK_TILE_HOST
static
constexpr
GemmK
a
rgs
MakeK
a
rgs
(
const
void
*
a_ptr
,
const
void
*
b_ptr
,
void
*
c_ptr
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
stride_A
,
index_t
stride_B
,
index_t
stride_C
)
CK_TILE_HOST
static
constexpr
GemmK
ernelA
rgs
MakeK
ernelA
rgs
(
const
void
*
a_ptr
,
const
void
*
b_ptr
,
void
*
c_ptr
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
stride_A
,
index_t
stride_B
,
index_t
stride_C
)
{
return
GemmK
a
rgs
{
a_ptr
,
b_ptr
,
c_ptr
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
};
return
GemmK
ernelA
rgs
{
a_ptr
,
b_ptr
,
c_ptr
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
};
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
...
...
@@ -82,7 +69,7 @@ struct GemmKernel
return
max
(
GemmPipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
}
CK_TILE_HOST
static
bool
IsSupportedArgument
(
const
GemmK
a
rgs
&
kargs
)
CK_TILE_HOST
static
bool
IsSupportedArgument
(
const
GemmK
ernelA
rgs
&
kargs
)
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
...
...
@@ -158,7 +145,7 @@ struct GemmKernel
CK_TILE_DEVICE
auto
MakeGemmTensorViews
(
const
ADataType
*
a_ptr
,
const
BDataType
*
b_ptr
,
CDataType
*
c_ptr
,
const
GemmK
a
rgs
&
kargs
)
const
const
GemmK
ernelA
rgs
&
kargs
)
const
{
auto
&&
a_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
...
...
@@ -311,14 +298,26 @@ struct GemmKernel
return
make_tuple
(
a_block_window
,
b_block_window
,
c_block_window
);
}
/**
* Create tensor views, pad views, tile windows, run gemm and epilogue pipeline
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param c_ptr output C pointer
* @param kargs GEMM kernel arguments
* @param block_idx_m M block index
* @param block_idx_n N block index
*
* @return Runs GEMM cooperatively by whole workgroup with CShuffle or Default 2D Epilogue
*/
CK_TILE_DEVICE
void
RunGemm
(
const
ADataType
*
a_ptr
,
const
BDataType
*
b_ptr
,
CDataType
*
c_ptr
,
const
GemmK
a
rgs
&
kargs
,
const
GemmK
ernelA
rgs
&
kargs
,
const
index_t
block_idx_m
,
const
index_t
block_idx_n
)
const
{
// C
onvert pointers to tensor vie
ws
// C
reate Gemm tensor views, pad views and tile windo
ws
auto
&&
gemm_tensor_views_tuple
=
MakeGemmTensorViews
(
a_ptr
,
b_ptr
,
c_ptr
,
kargs
);
auto
&&
gemm_pad_views
=
MakeGemmPadViews
(
gemm_tensor_views_tuple
);
auto
&&
gemm_tile_windows
=
MakeGemmTileWindows
(
gemm_pad_views
,
block_idx_m
,
block_idx_n
);
...
...
@@ -328,19 +327,18 @@ struct GemmKernel
const
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
kargs
.
K
);
auto
a_block_window
=
gemm_tile_windows
.
at
(
I0
);
auto
b_block_window
=
gemm_tile_windows
.
at
(
I1
);
// Run GEMM cooperatively by whole workgroup.
auto
c_block_tile
=
auto
&&
a_block_window
=
gemm_tile_windows
.
at
(
I0
);
auto
&&
b_block_window
=
gemm_tile_windows
.
at
(
I1
);
auto
&&
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
auto
c_block_window
=
gemm_tile_windows
.
at
(
I2
);
// Run CShuffle or Default 2D Epilogue
auto
&&
c_block_window
=
gemm_tile_windows
.
at
(
I2
);
EpiloguePipeline
{}(
c_block_window
,
c_block_tile
);
}
CK_TILE_DEVICE
void
operator
()(
GemmK
a
rgs
kargs
)
const
CK_TILE_DEVICE
void
operator
()(
GemmK
ernelA
rgs
kargs
)
const
{
const
auto
[
i_m
,
i_n
]
=
TilePartitioner
{}();
// options
...
...
include/ck_tile/ops/gemm/problem/gemm_problem.hpp
0 → 100644
View file @
b4f65acf
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
namespace
ck_tile
{
struct
Problem
{
CK_TILE_HOST
Problem
()
=
default
;
CK_TILE_HOST
Problem
(
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
stride_A_
,
index_t
stride_B_
,
index_t
stride_C_
)
:
M
(
M_
),
N
(
N_
),
K
(
K_
),
stride_A
(
stride_A_
),
stride_B
(
stride_B_
),
stride_C
(
stride_C_
)
{
}
index_t
M
;
index_t
N
;
index_t
K
;
index_t
stride_A
;
index_t
stride_B
;
index_t
stride_C
;
};
struct
GemmHostArgs
:
public
Problem
{
CK_TILE_HOST
GemmHostArgs
()
=
default
;
CK_TILE_HOST
GemmHostArgs
(
const
void
*
a_ptr_
,
const
void
*
b_ptr_
,
void
*
c_ptr_
,
index_t
k_batch_
,
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
stride_A_
,
index_t
stride_B_
,
index_t
stride_C_
)
:
Problem
(
M_
,
N_
,
K_
,
stride_A_
,
stride_B_
,
stride_C_
),
a_ptr
(
a_ptr_
),
b_ptr
(
b_ptr_
),
c_ptr
(
c_ptr_
),
k_batch
(
k_batch_
)
{
}
const
void
*
a_ptr
;
const
void
*
b_ptr
;
void
*
c_ptr
;
index_t
k_batch
;
};
}
// namespace ck_tile
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
View file @
b4f65acf
...
...
@@ -24,7 +24,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
struct
batched_gemm_kargs
:
public
ck_tile
::
BatchedGemmH
a
rgs
struct
batched_gemm_kargs
:
public
ck_tile
::
BatchedGemmH
ostA
rgs
{
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment