Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
853a5183
Unverified
Commit
853a5183
authored
Jan 30, 2025
by
arai713
Committed by
GitHub
Jan 30, 2025
Browse files
Merge branch 'develop' into codegen_hiprtc
parents
5574422a
e6d41804
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
28 deletions
+20
-28
test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
+20
-28
No files found.
test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
View file @
853a5183
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include <sstream>
#include <sstream>
...
@@ -26,12 +26,9 @@ class TestCkTileGroupedGemm : public ::testing::Test
...
@@ -26,12 +26,9 @@ class TestCkTileGroupedGemm : public ::testing::Test
struct
GroupedGemKernelParam
struct
GroupedGemKernelParam
{
{
static
const
bool
kPadM
=
false
;
static
const
bool
kPadM
=
false
;
static
const
bool
kPadN
=
false
;
static
const
bool
kPadN
=
false
;
static
const
bool
kPadK
=
false
;
static
const
bool
kPadK
=
false
;
static
const
bool
kTilePermute
=
false
;
static
const
ck_tile
::
index_t
kOutputRank
=
2
;
static
const
int
kBlockPerCu
=
1
;
static
const
int
kBlockPerCu
=
1
;
static
const
ck_tile
::
index_t
M_Tile
=
128
;
static
const
ck_tile
::
index_t
M_Tile
=
128
;
...
@@ -60,26 +57,6 @@ class TestCkTileGroupedGemm : public ::testing::Test
...
@@ -60,26 +57,6 @@ class TestCkTileGroupedGemm : public ::testing::Test
using
TilePartitioner
=
ck_tile
::
GemmTile1DPartitioner
<
CodegenGemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTile1DPartitioner
<
CodegenGemmShape
>
;
template
<
typename
CLayout
>
using
GemmEpilogue
=
std
::
conditional_t
<
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
GroupedGemKernelParam
::
kPadM
,
GroupedGemKernelParam
::
kPadN
,
GroupedGemKernelParam
::
kTilePermute
,
GroupedGemKernelParam
::
kOutputRank
,
1
,
0
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
GroupedGemKernelParam
::
kPadM
,
GroupedGemKernelParam
::
kPadN
>>>
;
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
GroupedGemKernelParam
::
kPadM
,
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
GroupedGemKernelParam
::
kPadM
,
GroupedGemKernelParam
::
kPadN
,
GroupedGemKernelParam
::
kPadN
,
...
@@ -100,10 +77,25 @@ class TestCkTileGroupedGemm : public ::testing::Test
...
@@ -100,10 +77,25 @@ class TestCkTileGroupedGemm : public ::testing::Test
using
CodegenGemmPipeline
=
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
<
ALayout
,
BLayout
,
CLayout
>>
;
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
<
ALayout
,
BLayout
,
CLayout
>>
;
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CLayout
,
CodegenGemmPipeline
<
ALayout
,
BLayout
,
CLayout
>::
BlockSize
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
,
GroupedGemKernelParam
::
M_Warp
,
GroupedGemKernelParam
::
N_Warp
,
GroupedGemKernelParam
::
M_Warp_Tile
,
GroupedGemKernelParam
::
N_Warp_Tile
,
GroupedGemKernelParam
::
K_Warp_Tile
,
CodegenPipelineProblem
<
ALayout
,
BLayout
,
CLayout
>::
TransposeC
>>
;
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
using
Kernel
=
ck_tile
::
GroupedGemmKernel
<
TilePartitioner
,
using
Kernel
=
ck_tile
::
GroupedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
<
ALayout
,
BLayout
,
CLayout
>
,
CodegenGemmPipeline
<
ALayout
,
BLayout
,
CLayout
>
,
GemmEpilogue
<
CLayout
>>
;
GemmEpilogue
<
ALayout
,
BLayout
,
CLayout
>>
;
using
grouped_gemm_kargs
=
ck_tile
::
GroupedGemmHostArgs
;
using
grouped_gemm_kargs
=
ck_tile
::
GroupedGemmHostArgs
;
std
::
size_t
GetWorkspaceSize
(
const
std
::
vector
<
grouped_gemm_kargs
>&
gemm_descs
)
std
::
size_t
GetWorkspaceSize
(
const
std
::
vector
<
grouped_gemm_kargs
>&
gemm_descs
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment