Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
824809c1
Commit
824809c1
authored
Oct 10, 2024
by
Adam Osewski
Browse files
Fixes after merge.
parent
4085e3d0
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
31 additions
and
47 deletions
+31
-47
example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp
example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp
+5
-15
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+7
-1
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
...ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
+3
-3
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
...m/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
+2
-2
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
.../ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
+2
-2
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
+7
-9
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
+5
-15
No files found.
example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp
View file @
824809c1
...
@@ -48,14 +48,10 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -48,14 +48,10 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
false
,
kPadC
>>
;
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
false
,
kPadC
>>
;
using
BaseGemmPipeline
=
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadA
,
kPadB
,
kPadC
,
ALayout
,
BLayout
,
CLayout
>
;
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
ck_tile
::
BlockGemmPipelineProblem
<
ADataType
,
BDataType
,
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
CDataType
,
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
;
GemmShape
,
ALayout
,
BLayout
,
CLayout
>>
;
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
args
.
K
);
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
args
.
K
);
const
bool
has_hot_loop
=
BaseGemmPipeline
::
BlockHasHotloop
(
num_loop
);
const
bool
has_hot_loop
=
BaseGemmPipeline
::
BlockHasHotloop
(
num_loop
);
...
@@ -71,14 +67,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -71,14 +67,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
BDataType
,
BDataType
,
AccDataType
,
AccDataType
,
CDataType
,
GemmShape
,
GemmShape
,
ALayout
,
Traits
,
BLayout
,
CLayout
,
kPadA
,
kPadB
,
kPadC
,
ck_tile
::
GemmPipelineScheduler
::
Intrawave
,
ck_tile
::
GemmPipelineScheduler
::
Intrawave
,
has_hot_loop_v
,
has_hot_loop_v
,
tail_number_v
>>
;
tail_number_v
>>
;
...
...
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
824809c1
...
@@ -164,7 +164,13 @@ int run_gemm_example(int argc, char* argv[])
...
@@ -164,7 +164,13 @@ int run_gemm_example(int argc, char* argv[])
c_m_n_gpu_ref.SetZero();
c_m_n_gpu_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero();
ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>(
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C);
a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C);
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
...
...
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
View file @
824809c1
...
@@ -18,7 +18,7 @@ struct BlockGemmASmemBSmemCRegV1
...
@@ -18,7 +18,7 @@ struct BlockGemmASmemBSmemCRegV1
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
Acc
DataType
=
remove_cvref_t
<
typename
Problem
::
Acc
DataType
>
;
using
C
DataType
=
remove_cvref_t
<
typename
Problem
::
C
DataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
...
@@ -31,7 +31,7 @@ struct BlockGemmASmemBSmemCRegV1
...
@@ -31,7 +31,7 @@ struct BlockGemmASmemBSmemCRegV1
{
{
static_assert
(
std
::
is_same_v
<
ADataType
,
typename
ABlockWindow
::
DataType
>
&&
static_assert
(
std
::
is_same_v
<
ADataType
,
typename
ABlockWindow
::
DataType
>
&&
std
::
is_same_v
<
BDataType
,
typename
BBlockWindow
::
DataType
>
&&
std
::
is_same_v
<
BDataType
,
typename
BBlockWindow
::
DataType
>
&&
std
::
is_same_v
<
Acc
DataType
,
typename
CBlockTensor
::
DataType
>
,
std
::
is_same_v
<
C
DataType
,
typename
CBlockTensor
::
DataType
>
,
"wrong!"
);
"wrong!"
);
constexpr
index_t
MPerBlock
=
ABlockWindow
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
MPerBlock
=
ABlockWindow
{}.
get_window_lengths
()[
number
<
0
>
{}];
...
@@ -195,7 +195,7 @@ struct BlockGemmASmemBSmemCRegV1
...
@@ -195,7 +195,7 @@ struct BlockGemmASmemBSmemCRegV1
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
Acc
DataType
>
(
c_block_dstr
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
C
DataType
>
(
c_block_dstr
);
return
c_block_tensor
;
return
c_block_tensor
;
}
}
...
...
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp
View file @
824809c1
...
@@ -17,7 +17,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
...
@@ -17,7 +17,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
{
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
half_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
BDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
BDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
Acc
DataType
,
float
>
)
std
::
is_same_v
<
typename
Problem
::
C
DataType
,
float
>
)
{
{
#if 0
#if 0
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kBlockSize = Problem::kBlockSize;
...
@@ -45,7 +45,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
...
@@ -45,7 +45,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
}
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
bf16_t
>
&&
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
BDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
BDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
Acc
DataType
,
float
>
)
std
::
is_same_v
<
typename
Problem
::
C
DataType
,
float
>
)
{
{
return
make_tuple
(
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
{},
4
,
1
);
return
make_tuple
(
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
{},
4
,
1
);
}
}
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
View file @
824809c1
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -87,7 +87,7 @@ struct BaseGemmPipelineAgBgCrMem
...
@@ -87,7 +87,7 @@ struct BaseGemmPipelineAgBgCrMem
// LocalPreFillStages: 1
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
// LocalSharedMemoryBuffer: 1
template
<
typename
Problem
,
typename
Policy
=
Block
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
>
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
>
struct
GemmPipelineAgBgCrMem
:
public
BaseGemmPipelineAgBgCrMem
<
Problem
>
struct
GemmPipelineAgBgCrMem
:
public
BaseGemmPipelineAgBgCrMem
<
Problem
>
{
{
using
Base
=
BaseGemmPipelineAgBgCrMem
<
Problem
>
;
using
Base
=
BaseGemmPipelineAgBgCrMem
<
Problem
>
;
...
...
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
View file @
824809c1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 20
18-2023
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 20
24
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
template
<
bool
kPadA_
,
template
<
bool
kPadA_
,
bool
kPadB_
,
bool
kPadB_
,
bool
kPadC_
,
bool
kPadC_
,
typename
Layout
A
_
,
typename
A
Layout_
,
typename
Layout
B
_
,
typename
B
Layout_
,
typename
Layout
C
_
>
typename
C
Layout_
>
struct
TileGemmTraits
struct
TileGemmTraits
{
{
static
constexpr
bool
kPadA
=
kPadA_
;
static
constexpr
bool
kPadA
=
kPadA_
;
static
constexpr
bool
kPadB
=
kPadB_
;
static
constexpr
bool
kPadB
=
kPadB_
;
static
constexpr
bool
kPadC
=
kPadC_
;
static
constexpr
bool
kPadC
=
kPadC_
;
using
Layout
A
=
Layout
A
_
;
using
A
Layout
=
A
Layout_
;
using
Layout
B
=
Layout
B
_
;
using
B
Layout
=
B
Layout_
;
using
Layout
C
=
Layout
C
_
;
using
C
Layout
=
C
Layout_
;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
View file @
824809c1
...
@@ -69,14 +69,10 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
...
@@ -69,14 +69,10 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
false
,
kPadC
>>
;
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
false
,
kPadC
>>
;
using
BaseGemmPipeline
=
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadA
,
kPadB
,
kPadC
,
ALayout
,
BLayout
,
CLayout
>
;
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
ck_tile
::
BlockGemmPipelineProblem
<
ADataType
,
BDataType
,
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
CDataType
,
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>
;
GemmShape
,
ALayout
,
BLayout
,
CLayout
>>
;
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
args
.
K
);
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
args
.
K
);
const
bool
has_hot_loop
=
BaseGemmPipeline
::
BlockHasHotloop
(
num_loop
);
const
bool
has_hot_loop
=
BaseGemmPipeline
::
BlockHasHotloop
(
num_loop
);
...
@@ -90,14 +86,8 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
...
@@ -90,14 +86,8 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
BDataType
,
BDataType
,
AccDataType
,
AccDataType
,
CDataType
,
GemmShape
,
GemmShape
,
ALayout
,
Traits
,
BLayout
,
CLayout
,
kPadA
,
kPadB
,
kPadC
,
ck_tile
::
GemmPipelineScheduler
::
Intrawave
,
ck_tile
::
GemmPipelineScheduler
::
Intrawave
,
has_hot_loop_v
,
has_hot_loop_v
,
tail_number_v
>>
;
tail_number_v
>>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment