Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
bbea596d
"...resnet50_tensorflow.git" did not exist on "faf4bbb36d576b0c328f179eb895eb6a48f284e1"
Commit
bbea596d
authored
Jan 24, 2025
by
ThomasNing
Browse files
debug the code
parent
3935554a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
59 additions
and
44 deletions
+59
-44
example/ck_tile/03_gemm/CMakeLists.txt
example/ck_tile/03_gemm/CMakeLists.txt
+3
-0
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+2
-8
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
+44
-30
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+10
-6
No files found.
example/ck_tile/03_gemm/CMakeLists.txt
View file @
bbea596d
add_executable
(
tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp
)
add_executable
(
tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp
)
target_compile_options
(
tile_example_gemm_basic PRIVATE
-mllvm -enable-noalias-to-md-conversion=0
)
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
bbea596d
...
...
@@ -69,12 +69,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>>
;
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
,
true
,
3
>
;
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPolicy
=
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
,
CodegenGemmPolicy
>
;
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
...
...
@@ -92,11 +91,6 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
kbatch
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
!
Kernel
::
IsSupportedArgument
(
kargs
))
{
throw
std
::
runtime_error
(
"Wrong! Arguments not supported! Skipping gemm!
\n
"
);
}
if
(
s
.
log_level_
>
0
)
{
std
::
cout
<<
"Launching kernel with args:"
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
View file @
bbea596d
...
...
@@ -40,6 +40,37 @@ struct BlockGemmARegBRegCRegV2
static
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
static
constexpr
index_t
KIterPerWarp
=
KPerBlock
/
WG
::
kK
;
CK_TILE_DEVICE
static
constexpr
auto
MakeABlockDistributionEncode
()
{
constexpr
auto
a_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_block_outer_dstr_encoding
,
typename
WG
::
AWarpDstrEncoding
{});
return
a_block_dstr_encode
;
}
CK_TILE_DEVICE
static
constexpr
auto
MakeBBlockDistributionEncode
()
{
constexpr
auto
b_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
b_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
b_block_outer_dstr_encoding
,
typename
WG
::
BWarpDstrEncoding
{});
return
b_block_dstr_encode
;
}
// C += A * B
template
<
typename
CBlockTensor
,
typename
ABlockTensor
,
typename
BBlockTensor
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
...
...
@@ -51,8 +82,8 @@ struct BlockGemmARegBRegCRegV2
std
::
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
"wrong!"
);
constexpr
auto
a_block_dstr_encode
=
MakeABlockDistribution
();
constexpr
auto
b_block_dstr_encode
=
MakeBBlockDistribution
();
constexpr
auto
a_block_dstr_encode
=
MakeABlockDistribution
Encode
();
constexpr
auto
b_block_dstr_encode
=
MakeBBlockDistribution
Encode
();
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
...
...
@@ -173,38 +204,21 @@ struct BlockGemmARegBRegCRegV2
// auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
// return c_block_tensor;
// }
CK_TILE_DEVICE
static
constexpr
auto
MakeABlockDistribution
()
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockDistributionEncode
()
{
constexpr
auto
a_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
NWarp
>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
constexpr
auto
a_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
a_block_outer_dstr_encoding
,
typename
WG
::
AWarpDstrEncoding
{});
constexpr
auto
a_block_dstr
=
make_static_tile_distribution
(
a_block_dstr_encode
);
return
a_block_dstr
;
return
c_block_dstr_encode
;
}
CK_TILE_DEVICE
static
constexpr
auto
MakeBBlockDistribution
()
{
constexpr
auto
b_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
b_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
b_block_outer_dstr_encoding
,
typename
WG
::
BWarpDstrEncoding
{});
constexpr
auto
b_block_dstr
=
make_static_tile_distribution
(
b_block_dstr_encode
);
return
b_block_dstr
;
}
// C = A * B
template
<
typename
ABlockTensor
,
typename
BBlockTensor
>
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
bbea596d
...
...
@@ -149,8 +149,8 @@ struct GemmPipelineAGmemBGmemCRegV1
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
BElementFunction
&
b_element_func
,
index_t
num_loop
,
void
*
__restrict__
p_smem_0
,
void
*
__restrict__
p_smem_1
)
void
*
p_smem_0
,
void
*
p_smem_1
)
{
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cvref_t
<
typename
ADramBlockWindowTmp
::
DataType
>>
&&
...
...
@@ -238,8 +238,12 @@ struct GemmPipelineAGmemBGmemCRegV1
block_sync_lds
();
constexpr
auto
ALdsTileDistr
=
decltype
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeABlockDistribution
()){};
constexpr
auto
BLdsTileDistr
=
decltype
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeBBlockDistribution
()){};
// constexpr auto ALdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeABlockDistribution()){};
// constexpr auto BLdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeBBlockDistribution()){};
constexpr
auto
ALdsTileDistr
=
decltype
(
make_static_tile_distribution
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeABlockDistributionEncode
())){};
constexpr
auto
BLdsTileDistr
=
decltype
(
make_static_tile_distribution
(
Policy
::
template
BlockGemm
<
Problem
>
::
MakeBBlockDistributionEncode
())){};
using
ALdsTile
=
decltype
(
make_static_distributed_tensor
<
ADataType
>
(
ALdsTileDistr
));
using
BLdsTile
=
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BLdsTileDistr
));
ALdsTile
a_block_tile0
;
...
...
@@ -359,8 +363,8 @@ struct GemmPipelineAGmemBGmemCRegV1
CK_TILE_DEVICE
static
auto
run
(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
index_t
num_loop
,
void
*
__restrict__
p_smem_0
,
void
*
__restrict__
p_smem_1
)
void
*
p_smem_0
,
void
*
p_smem_1
)
{
return
run
(
a_dram_block_window_tmp
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment