Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b85e1128
Commit
b85e1128
authored
Dec 18, 2024
by
Aleksander Dudek
Browse files
[CK_TILE] Refactor GemmKernel - constness fixes
parent
c1f51cd4
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
78 additions
and
90 deletions
+78
-90
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+1
-9
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+1
-1
example/ck_tile/16_batched_gemm/batched_gemm.cpp
example/ck_tile/16_batched_gemm/batched_gemm.cpp
+1
-13
example/ck_tile/16_batched_gemm/batched_gemm.hpp
example/ck_tile/16_batched_gemm/batched_gemm.hpp
+1
-1
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
+47
-18
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+27
-48
No files found.
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
b85e1128
...
@@ -79,15 +79,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -79,15 +79,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
.
a_ptr
,
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
args
.
b_ptr
,
args
.
c_ptr
,
args
.
M
,
args
.
N
,
args
.
K
,
args
.
stride_A
,
args
.
stride_B
,
args
.
stride_C
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
k_batch
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
k_batch
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
...
...
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
b85e1128
...
@@ -75,4 +75,4 @@ auto create_args(int argc, char* argv[])
...
@@ -75,4 +75,4 @@ auto create_args(int argc, char* argv[])
}
}
// host API
// host API
float
gemm_calc
(
ck_tile
::
GemmHostArgs
args
,
const
ck_tile
::
stream_config
&
s
);
float
gemm_calc
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
);
example/ck_tile/16_batched_gemm/batched_gemm.cpp
View file @
b85e1128
...
@@ -79,19 +79,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
...
@@ -79,19 +79,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
using
Kernel
=
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
.
a_ptr
,
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
args
.
b_ptr
,
args
.
c_ptr
,
args
.
M
,
args
.
N
,
args
.
K
,
args
.
stride_A
,
args
.
stride_B
,
args
.
stride_C
,
args
.
batch_stride_A
,
args
.
batch_stride_B
,
args
.
batch_stride_C
,
args
.
batch_count
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
batch_count
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
batch_count
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
...
...
example/ck_tile/16_batched_gemm/batched_gemm.hpp
View file @
b85e1128
...
@@ -56,4 +56,4 @@ auto create_args(int argc, char* argv[])
...
@@ -56,4 +56,4 @@ auto create_args(int argc, char* argv[])
}
}
// host API
// host API
float
batched_gemm
(
ck_tile
::
BatchedGemmHostArgs
args
,
const
ck_tile
::
stream_config
&
s
);
float
batched_gemm
(
const
ck_tile
::
BatchedGemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
);
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
View file @
b85e1128
...
@@ -7,6 +7,38 @@
...
@@ -7,6 +7,38 @@
namespace
ck_tile
{
namespace
ck_tile
{
struct
BatchedGemmHostArgs
:
public
ck_tile
::
GemmHostArgs
{
CK_TILE_HOST
BatchedGemmHostArgs
()
=
default
;
CK_TILE_HOST
BatchedGemmHostArgs
(
const
void
*
a_ptr_
,
const
void
*
b_ptr_
,
void
*
c_ptr_
,
ck_tile
::
index_t
k_batch_
,
ck_tile
::
index_t
M_
,
ck_tile
::
index_t
N_
,
ck_tile
::
index_t
K_
,
ck_tile
::
index_t
stride_A_
,
ck_tile
::
index_t
stride_B_
,
ck_tile
::
index_t
stride_C_
,
ck_tile
::
index_t
batch_stride_A_
,
ck_tile
::
index_t
batch_stride_B_
,
ck_tile
::
index_t
batch_stride_C_
,
ck_tile
::
index_t
batch_count_
)
:
GemmHostArgs
(
a_ptr_
,
b_ptr_
,
c_ptr_
,
k_batch_
,
M_
,
N_
,
K_
,
stride_A_
,
stride_B_
,
stride_C_
),
batch_stride_A
(
batch_stride_A_
),
batch_stride_B
(
batch_stride_B_
),
batch_stride_C
(
batch_stride_C_
),
batch_count
(
batch_count_
)
{
}
ck_tile
::
index_t
batch_stride_A
;
ck_tile
::
index_t
batch_stride_B
;
ck_tile
::
index_t
batch_stride_C
;
ck_tile
::
index_t
batch_count
;
};
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
>
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
>
struct
BatchedGemmKernel
:
public
GemmKernel
<
TilePartitioner_
,
GemmPipeline_
,
EpiloguePipeline_
>
struct
BatchedGemmKernel
:
public
GemmKernel
<
TilePartitioner_
,
GemmPipeline_
,
EpiloguePipeline_
>
{
{
...
@@ -42,25 +74,22 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
...
@@ -42,25 +74,22 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
Base
::
KernelBlockSize
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
Base
::
KernelBlockSize
);
}
CK_TILE_HOST
static
constexpr
BatchedGemmKernelArgs
MakeKernelArgs
(
const
void
*
a_ptr
,
CK_TILE_HOST
static
constexpr
BatchedGemmKernelArgs
const
void
*
b_ptr
,
MakeKernelArgs
(
const
BatchedGemmHostArgs
&
hostArgs
)
void
*
c_ptr
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
stride_A
,
index_t
stride_B
,
index_t
stride_C
,
index_t
batch_stride_A
,
index_t
batch_stride_B
,
index_t
batch_stride_C
,
index_t
batch_count
)
{
{
return
BatchedGemmKernelArgs
{{
a_ptr
,
b_ptr
,
c_ptr
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
},
return
BatchedGemmKernelArgs
{{
hostArgs
.
a_ptr
,
batch_stride_A
,
hostArgs
.
b_ptr
,
batch_stride_B
,
hostArgs
.
c_ptr
,
batch_stride_C
,
hostArgs
.
M
,
batch_count
};
hostArgs
.
N
,
hostArgs
.
K
,
hostArgs
.
stride_A
,
hostArgs
.
stride_B
,
hostArgs
.
stride_C
},
hostArgs
.
batch_stride_A
,
hostArgs
.
batch_stride_B
,
hostArgs
.
batch_stride_C
,
hostArgs
.
batch_count
};
}
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
b85e1128
...
@@ -56,38 +56,6 @@ struct GemmHostArgs : public GemmProblem
...
@@ -56,38 +56,6 @@ struct GemmHostArgs : public GemmProblem
index_t
k_batch
;
index_t
k_batch
;
};
};
struct
BatchedGemmHostArgs
:
public
ck_tile
::
GemmHostArgs
{
CK_TILE_HOST
BatchedGemmHostArgs
()
=
default
;
CK_TILE_HOST
BatchedGemmHostArgs
(
const
void
*
a_ptr_
,
const
void
*
b_ptr_
,
void
*
c_ptr_
,
ck_tile
::
index_t
k_batch_
,
ck_tile
::
index_t
M_
,
ck_tile
::
index_t
N_
,
ck_tile
::
index_t
K_
,
ck_tile
::
index_t
stride_A_
,
ck_tile
::
index_t
stride_B_
,
ck_tile
::
index_t
stride_C_
,
ck_tile
::
index_t
batch_stride_A_
,
ck_tile
::
index_t
batch_stride_B_
,
ck_tile
::
index_t
batch_stride_C_
,
ck_tile
::
index_t
batch_count_
)
:
GemmHostArgs
(
a_ptr_
,
b_ptr_
,
c_ptr_
,
k_batch_
,
M_
,
N_
,
K_
,
stride_A_
,
stride_B_
,
stride_C_
),
batch_stride_A
(
batch_stride_A_
),
batch_stride_B
(
batch_stride_B_
),
batch_stride_C
(
batch_stride_C_
),
batch_count
(
batch_count_
)
{
}
ck_tile
::
index_t
batch_stride_A
;
ck_tile
::
index_t
batch_stride_B
;
ck_tile
::
index_t
batch_stride_C
;
ck_tile
::
index_t
batch_count
;
};
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
>
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
>
struct
GemmKernel
struct
GemmKernel
{
{
...
@@ -127,18 +95,30 @@ struct GemmKernel
...
@@ -127,18 +95,30 @@ struct GemmKernel
index_t
stride_C
;
index_t
stride_C
;
};
};
CK_TILE_HOST
static
constexpr
GemmKernelArgs
MakeKernelArgs
(
const
void
*
a_ptr
,
CK_TILE_HOST
static
constexpr
GemmKernelArgs
MakeKernelArgs
(
GemmHostArgs
&
hostArgs
)
const
void
*
b_ptr
,
void
*
c_ptr
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
stride_A
,
index_t
stride_B
,
index_t
stride_C
)
{
{
return
GemmKernelArgs
{
a_ptr
,
b_ptr
,
c_ptr
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
};
return
GemmKernelArgs
{
hostArgs
.
a_ptr
,
hostArgs
.
b_ptr
,
hostArgs
.
c_ptr
,
hostArgs
.
M
,
hostArgs
.
N
,
hostArgs
.
K
,
hostArgs
.
stride_A
,
hostArgs
.
stride_B
,
hostArgs
.
stride_C
};
}
}
// CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const void* a_ptr,
// const void* b_ptr,
// void* c_ptr,
// index_t M,
// index_t N,
// index_t K,
// index_t stride_A,
// index_t stride_B,
// index_t stride_C)
// {
// return GemmKernelArgs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
// }
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
{
...
@@ -365,8 +345,8 @@ struct GemmKernel
...
@@ -365,8 +345,8 @@ struct GemmKernel
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_n
,
0
});
{
i_n
,
0
});
const
auto
&
c_pad_view
=
views
.
at
(
I2
);
const
auto
&
c_pad_view
=
views
.
at
(
I2
);
const
auto
&
c_block_window
=
make_tile_window
(
auto
c_block_window
=
make_tile_window
(
c_pad_view
,
c_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
i_m
,
i_n
});
{
i_m
,
i_n
});
...
@@ -394,8 +374,7 @@ struct GemmKernel
...
@@ -394,8 +374,7 @@ struct GemmKernel
// Create Gemm tensor views, pad views and tile windows
// Create Gemm tensor views, pad views and tile windows
const
auto
&
gemm_tensor_views_tuple
=
MakeGemmTensorViews
(
a_ptr
,
b_ptr
,
c_ptr
,
kargs
);
const
auto
&
gemm_tensor_views_tuple
=
MakeGemmTensorViews
(
a_ptr
,
b_ptr
,
c_ptr
,
kargs
);
const
auto
&
gemm_pad_views
=
MakeGemmPadViews
(
gemm_tensor_views_tuple
);
const
auto
&
gemm_pad_views
=
MakeGemmPadViews
(
gemm_tensor_views_tuple
);
const
auto
&
gemm_tile_windows
=
auto
gemm_tile_windows
=
MakeGemmTileWindows
(
gemm_pad_views
,
block_idx_m
,
block_idx_n
);
MakeGemmTileWindows
(
gemm_pad_views
,
block_idx_m
,
block_idx_n
);
// allocate LDS
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
__shared__
char
smem_ptr
[
GetSmemSize
()];
...
@@ -405,11 +384,11 @@ struct GemmKernel
...
@@ -405,11 +384,11 @@ struct GemmKernel
// Run GEMM cooperatively by whole workgroup.
// Run GEMM cooperatively by whole workgroup.
const
auto
&
a_block_window
=
gemm_tile_windows
.
at
(
I0
);
const
auto
&
a_block_window
=
gemm_tile_windows
.
at
(
I0
);
const
auto
&
b_block_window
=
gemm_tile_windows
.
at
(
I1
);
const
auto
&
b_block_window
=
gemm_tile_windows
.
at
(
I1
);
auto
c_block_tile
=
const
auto
&
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
// Run Epilogue Pipeline
// Run Epilogue Pipeline
auto
c_block_window
=
gemm_tile_windows
.
at
(
I2
);
auto
&
c_block_window
=
gemm_tile_windows
.
at
(
I2
);
EpiloguePipeline
{}(
c_block_window
,
c_block_tile
);
EpiloguePipeline
{}(
c_block_window
,
c_block_tile
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment