Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
114674c1
Unverified
Commit
114674c1
authored
Feb 06, 2025
by
Muhammed Emin Ozturk
Committed by
GitHub
Feb 06, 2025
Browse files
Merge branch 'develop' into muozturk_sk_padding
parents
a199815a
5bb041ba
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
251 additions
and
72 deletions
+251
-72
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
.../ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
+236
-55
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
...gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+1
-1
test/ck_tile/gemm/test_gemm_pipeline.cpp
test/ck_tile/gemm/test_gemm_pipeline.cpp
+14
-16
No files found.
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
View file @
114674c1
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
View file @
114674c1
...
...
@@ -519,7 +519,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
static_assert
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
);
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
k
N
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
k
M
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
VecLoadSize
=
GetVectorSizeA
<
Problem
>
();
...
...
test/ck_tile/gemm/test_gemm_pipeline.cpp
View file @
114674c1
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
...
...
@@ -14,28 +14,26 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Intrawave
=
ck_tile
::
integral_constant
<
ck_tile
::
GemmPipelineScheduler
,
ck_tile
::
GemmPipelineScheduler
::
Intrawave
>
;
// using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
// ck_tile::GemmPipelineScheduler::Interwave>;
// using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using
Comp
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Comp
>
;
// TODO: Enable Memory pipeline, when it would be updated for vector loads on non-K major tensors.
using
Interwave
=
ck_tile
::
integral_constant
<
ck_tile
::
GemmPipelineScheduler
,
ck_tile
::
GemmPipelineScheduler
::
Interwave
>
;
using
Mem
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Mem
>
;
using
Comp
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Comp
>
;
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
//
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
//
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
//
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
//
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
//
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
//
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
//
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
//
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
>
;
// clang-format on
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment