Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a60cf0d0
"...composable_kernel.git" did not exist on "986182fc63e4352dbdc75a2cb2488cefa4a2e976"
Commit
a60cf0d0
authored
Oct 10, 2024
by
Adam Osewski
Browse files
Use AccDataType for Output of MFMA instruction.
parent
b045fad5
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
9 additions
and
6 deletions
+9
-6
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+1
-1
example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp
example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp
+1
-0
include/ck_tile/ops/gemm/block/block_gemm_as_bs_cr.hpp
include/ck_tile/ops/gemm/block/block_gemm_as_bs_cr.hpp
+3
-3
include/ck_tile/ops/gemm/block/block_gemm_as_bs_cr_default_policy.hpp
...ile/ops/gemm/block/block_gemm_as_bs_cr_default_policy.hpp
+2
-2
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
...ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
+2
-0
No files found.
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
a60cf0d0
...
@@ -18,7 +18,7 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
...
@@ -18,7 +18,7 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
using
ADataType
=
ck_tile
::
half_t
;
using
ADataType
=
ck_tile
::
half_t
;
using
BDataType
=
ck_tile
::
half_t
;
using
BDataType
=
ck_tile
::
half_t
;
using
AccDataType
=
float
;
using
AccDataType
=
float
;
using
CDataType
=
floa
t
;
using
CDataType
=
ck_tile
::
half_
t
;
// ToDo: Add more bias config to support different categories of GEMM.
// ToDo: Add more bias config to support different categories of GEMM.
};
};
...
...
example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp
View file @
a60cf0d0
...
@@ -91,6 +91,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
...
@@ -91,6 +91,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
using
GemmPipeline
=
ck_tile
::
GemmPipelineAgBgCrMem
<
using
GemmPipeline
=
ck_tile
::
GemmPipelineAgBgCrMem
<
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
BDataType
,
BDataType
,
AccDataType
,
CDataType
,
CDataType
,
GemmShape
,
GemmShape
,
ALayout
,
ALayout
,
...
...
include/ck_tile/ops/gemm/block/block_gemm_as_bs_cr.hpp
View file @
a60cf0d0
...
@@ -18,7 +18,7 @@ struct BlockGemmAsBsCr
...
@@ -18,7 +18,7 @@ struct BlockGemmAsBsCr
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
C
DataType
=
remove_cvref_t
<
typename
Problem
::
C
DataType
>
;
using
Acc
DataType
=
remove_cvref_t
<
typename
Problem
::
Acc
DataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
...
@@ -31,7 +31,7 @@ struct BlockGemmAsBsCr
...
@@ -31,7 +31,7 @@ struct BlockGemmAsBsCr
{
{
static_assert
(
std
::
is_same_v
<
ADataType
,
typename
ABlockWindowTmp
::
DataType
>
&&
static_assert
(
std
::
is_same_v
<
ADataType
,
typename
ABlockWindowTmp
::
DataType
>
&&
std
::
is_same_v
<
BDataType
,
typename
BBlockWindowTmp
::
DataType
>
&&
std
::
is_same_v
<
BDataType
,
typename
BBlockWindowTmp
::
DataType
>
&&
std
::
is_same_v
<
C
DataType
,
typename
CBlockTensor
::
DataType
>
,
std
::
is_same_v
<
Acc
DataType
,
typename
CBlockTensor
::
DataType
>
,
"wrong!"
);
"wrong!"
);
constexpr
index_t
MPerBlock
=
ABlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
MPerBlock
=
ABlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
...
@@ -195,7 +195,7 @@ struct BlockGemmAsBsCr
...
@@ -195,7 +195,7 @@ struct BlockGemmAsBsCr
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
C
DataType
>
(
c_block_dstr
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
Acc
DataType
>
(
c_block_dstr
);
return
c_block_tensor
;
return
c_block_tensor
;
}
}
...
...
include/ck_tile/ops/gemm/block/block_gemm_as_bs_cr_default_policy.hpp
View file @
a60cf0d0
...
@@ -17,7 +17,7 @@ struct BlockGemmAsBsCrDefaultPolicy
...
@@ -17,7 +17,7 @@ struct BlockGemmAsBsCrDefaultPolicy
{
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
half_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
BDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
BDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
C
DataType
,
float
>
)
std
::
is_same_v
<
typename
Problem
::
Acc
DataType
,
float
>
)
{
{
#if 0
#if 0
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kBlockSize = Problem::kBlockSize;
...
@@ -45,7 +45,7 @@ struct BlockGemmAsBsCrDefaultPolicy
...
@@ -45,7 +45,7 @@ struct BlockGemmAsBsCrDefaultPolicy
}
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
bf16_t
>
&&
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
BDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
BDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
C
DataType
,
float
>
)
std
::
is_same_v
<
typename
Problem
::
Acc
DataType
,
float
>
)
{
{
return
make_tuple
(
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
{},
4
,
1
);
return
make_tuple
(
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
{},
4
,
1
);
}
}
...
...
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
View file @
a60cf0d0
...
@@ -42,6 +42,7 @@ struct BlockGemmPipelineProblem
...
@@ -42,6 +42,7 @@ struct BlockGemmPipelineProblem
template
<
typename
ADataType_
,
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
BDataType_
,
typename
AccDataType_
,
typename
CDataType_
,
typename
CDataType_
,
typename
BlockGemmShape_
,
typename
BlockGemmShape_
,
typename
ALayout_
,
typename
ALayout_
,
...
@@ -57,6 +58,7 @@ struct UniversalGemmPipelineProblem
...
@@ -57,6 +58,7 @@ struct UniversalGemmPipelineProblem
{
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment