Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4525c5d7
Commit
4525c5d7
authored
Dec 02, 2024
by
coderfeli
Browse files
merge upstream
parents
a8d88d8d
44828b7c
Changes
308
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
420 additions
and
230 deletions
+420
-230
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
+225
-0
test/ck_tile/gemm/test_gemm_mem_pipeline.cpp
test/ck_tile/gemm/test_gemm_mem_pipeline.cpp
+16
-3
test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc
test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc
+55
-4
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
+13
-12
test/grouped_gemm/CMakeLists.txt
test/grouped_gemm/CMakeLists.txt
+0
-6
test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp
test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp
+28
-18
test/grouped_gemm/test_grouped_gemm_ut_cases.inc
test/grouped_gemm/test_grouped_gemm_ut_cases.inc
+13
-118
test/grouped_gemm/test_grouped_gemm_util.hpp
test/grouped_gemm/test_grouped_gemm_util.hpp
+70
-69
No files found.
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
0 → 100644
View file @
4525c5d7
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <sstream>
#include <gtest/gtest.h>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
template
<
typename
Tuple
>
class
TestCkTileBatchedGemm
:
public
::
testing
::
Test
{
protected:
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
ADataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
struct
batched_gemm_kargs
:
public
ck_tile
::
BatchedGemmHostArgs
{
};
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
void
invoke_batched_gemm
(
const
batched_gemm_kargs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kTilePermute
=
false
;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
constexpr
int
kBlockPerCu
=
1
;
// This part comes from the Codegen
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
constexpr
ck_tile
::
index_t
N_Tile
=
128
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
constexpr
ck_tile
::
index_t
K_Warp
=
1
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr
bool
CShuffleEpilogue
=
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
using
CodegenGemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
CodegenGemmShape
>
;
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
,
kTilePermute
,
kOutputRank
,
1
,
0
,
TilePartitioner
::
kM
,
TilePartitioner
::
kN
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>>
;
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
Kernel
=
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
s
.
log_level_
>
0
)
{
std
::
cout
<<
"Launching kernel with args:"
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
std
::
endl
;
}
ck_tile
::
launch_kernel
(
s
,
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
Kernel
{},
grids
,
blocks
,
0
,
kargs
));
}
public:
void
Run
(
const
int
M
,
const
int
N
,
const
int
K
,
int
StrideA
=
128
,
int
StrideB
=
128
,
int
StrideC
=
128
,
const
int
BatchStrideA
=
32768
,
const
int
BatchStrideB
=
16384
,
const
int
BatchStrideC
=
32768
,
const
int
BatchCount
=
16
)
{
using
namespace
ck_tile
::
literals
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
batch_count_
,
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
std
::
size_t
batch_stride
,
auto
layout
)
{
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
ck_tile
::
HostTensorDescriptor
({
batch_count_
,
row
,
col
},
{
batch_stride
,
stride
,
1
_uz
});
}
else
{
return
ck_tile
::
HostTensorDescriptor
({
batch_count_
,
row
,
col
},
{
batch_stride
,
1
_uz
,
stride
});
}
};
auto
f_get_default_stride
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
stride
==
0
)
{
// give a chance if stride is zero, return a default packed stride
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
col
;
}
else
{
return
row
;
}
}
else
return
stride
;
};
StrideA
=
f_get_default_stride
(
M
,
K
,
StrideA
,
ALayout
{});
StrideB
=
f_get_default_stride
(
K
,
N
,
StrideB
,
BLayout
{});
StrideC
=
f_get_default_stride
(
M
,
N
,
StrideC
,
CLayout
{});
ck_tile
::
HostTensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
K
,
StrideA
,
BatchStrideA
,
ALayout
{}));
ck_tile
::
HostTensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
BatchCount
,
K
,
N
,
StrideB
,
BatchStrideB
,
BLayout
{}));
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_dev_result
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
N
,
StrideC
,
BatchStrideC
,
CLayout
{}));
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
ck_tile
::
FillUniformDistribution
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
ck_tile
::
DeviceMem
a_m_k_dev_buf
(
a_m_k
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
b_k_n_dev_buf
(
b_k_n
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
c_m_n_dev_buf
(
c_m_n_dev_result
.
get_element_space_size_in_bytes
());
a_m_k_dev_buf
.
ToDevice
(
a_m_k
.
data
());
b_k_n_dev_buf
.
ToDevice
(
b_k_n
.
data
());
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
batched_gemm_kargs
kargs
{
a_m_k_dev_buf
.
GetDeviceBuffer
(),
b_k_n_dev_buf
.
GetDeviceBuffer
(),
c_m_n_dev_buf
.
GetDeviceBuffer
(),
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
BatchStrideA
,
BatchStrideB
,
BatchStrideC
,
BatchCount
};
invoke_batched_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
kargs
,
ck_tile
::
stream_config
{
nullptr
,
false
});
std
::
cout
<<
"Run kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
<<
" StrideA ="
<<
StrideA
<<
" StrideB ="
<<
StrideB
<<
" StrideC ="
<<
StrideC
<<
" BatchStrideA ="
<<
BatchStrideA
<<
" BatchStrideB ="
<<
BatchStrideB
<<
" BatchStrideC ="
<<
BatchStrideC
<<
" BatchCount ="
<<
BatchCount
<<
std
::
endl
;
c_m_n_dev_buf
.
FromDevice
(
c_m_n_dev_result
.
data
());
bool
pass
=
true
;
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_host_ref
(
f_host_tensor_descriptor
(
BatchCount
,
M
,
N
,
StrideC
,
BatchStrideC
,
CLayout
{}));
c_m_n_host_ref
.
SetZero
();
const
auto
b_n_k
=
b_k_n
.
transpose
({
0
,
2
,
1
});
ck_tile
::
reference_batched_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
a_m_k
,
b_n_k
,
c_m_n_host_ref
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_host_ref
);
EXPECT_TRUE
(
pass
);
}
};
test/ck_tile/gemm/test_gemm_mem_pipeline.cpp
View file @
4525c5d7
...
...
@@ -11,8 +11,20 @@
using
F16
=
ck_tile
::
half_t
;
using
F32
=
float
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
static
constexpr
auto
Intrawave
=
ck_tile
::
GemmPipelineScheduler
::
Intrawave
;
static
constexpr
auto
Interwave
=
ck_tile
::
GemmPipelineScheduler
::
Interwave
;
template
<
typename
Tuple
>
class
TestCkTileGemmMemPipelineIntrawave
:
public
TestCkTileGemmMemPipeline
<
Tuple
,
Intrawave
>
{
};
template
<
typename
Tuple
>
class
TestCkTileGemmMemPipelineInterwave
:
public
TestCkTileGemmMemPipeline
<
Tuple
,
Interwave
>
{
};
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
...
...
@@ -24,6 +36,7 @@ using KernelTypes = ::testing::Types<
>
;
// clang-format on
TYPED_TEST_SUITE
(
TestCkTileGemmMemPipeline
,
KernelTypes
);
TYPED_TEST_SUITE
(
TestCkTileGemmMemPipelineIntrawave
,
KernelTypes
);
TYPED_TEST_SUITE
(
TestCkTileGemmMemPipelineInterwave
,
KernelTypes
);
#include "test_gemm_mem_pipeline_ut_cases.inc"
test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc
View file @
4525c5d7
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
TYPED_TEST
(
TestCkTileGemmMemPipeline
,
SmallM
)
//------------------------------------------------------------------------------------------------
// INTERWAVE SCHEDULER
//------------------------------------------------------------------------------------------------
TYPED_TEST
(
TestCkTileGemmMemPipelineInterwave
,
SmallM
)
{
std
::
vector
<
int
>
Ms
{
1
,
2
,
3
,
4
,
5
,
6
};
constexpr
int
N
=
1024
;
constexpr
int
K
=
320
;
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemmMemPipelineInterwave
,
MidLargeM
)
{
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
constexpr
int
N
=
1024
;
constexpr
int
K
=
320
;
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemmMemPipelineInterwave
,
PaddK
)
{
std
::
vector
<
int
>
Ms
{
127
};
constexpr
int
N
=
1024
;
constexpr
int
K
=
432
;
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemmMemPipelineInterwave
,
Regular
)
{
std
::
vector
<
int
>
Ms
{
512
};
constexpr
int
N
=
1024
;
constexpr
int
K
=
512
;
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
);
}
//------------------------------------------------------------------------------------------------
// INTRAWAVE SCHEDULER
//------------------------------------------------------------------------------------------------
TYPED_TEST
(
TestCkTileGemmMemPipelineIntrawave
,
SmallM
)
{
std
::
vector
<
int
>
Ms
{
1
,
2
,
3
,
4
,
5
,
6
};
constexpr
int
N
=
1024
;
...
...
@@ -10,7 +61,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, SmallM)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemmMemPipeline
,
MidLargeM
)
TYPED_TEST
(
TestCkTileGemmMemPipeline
Intrawave
,
MidLargeM
)
{
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
constexpr
int
N
=
1024
;
...
...
@@ -20,7 +71,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemmMemPipeline
,
PaddK
)
TYPED_TEST
(
TestCkTileGemmMemPipeline
Intrawave
,
PaddK
)
{
std
::
vector
<
int
>
Ms
{
127
};
constexpr
int
N
=
1024
;
...
...
@@ -30,7 +81,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, PaddK)
this
->
Run
(
M
,
N
,
K
);
}
TYPED_TEST
(
TestCkTileGemmMemPipeline
,
Regular
)
TYPED_TEST
(
TestCkTileGemmMemPipeline
Intrawave
,
Regular
)
{
std
::
vector
<
int
>
Ms
{
512
};
constexpr
int
N
=
1024
;
...
...
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
View file @
4525c5d7
...
...
@@ -11,20 +11,21 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
template
<
typename
Tuple
>
template
<
typename
Tuple
,
ck_tile
::
GemmPipelineScheduler
Scheduler_
>
class
TestCkTileGemmMemPipeline
:
public
::
testing
::
Test
{
protected:
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
ADataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
ADataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
static
constexpr
auto
Scheduler
=
Scheduler_
;
// TODO: expose tile size through test t-param ?
struct
gemm_
basic_
args
struct
gemm_args
{
const
void
*
p_a
;
const
void
*
p_b
;
...
...
@@ -38,7 +39,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
ck_tile
::
index_t
stride_C
;
};
void
invoke_gemm
(
const
gemm_
basic_
args
&
args
,
const
ck_tile
::
stream_config
&
s
)
void
invoke_gemm
(
const
gemm_args
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
// TODO: This should be parameterized in tests
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
...
...
@@ -89,7 +90,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
AccDataType
,
GemmShape
,
Traits
,
ck_tile
::
GemmPipelineScheduler
::
Intrawave
,
Scheduler
,
has_hot_loop_v
,
tail_number_v
>>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
...
...
@@ -288,7 +289,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
gemm_
basic_
args
args
;
gemm_args
args
;
args
.
p_a
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
p_b
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
p_c
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
...
...
test/grouped_gemm/CMakeLists.txt
View file @
4525c5d7
...
...
@@ -6,12 +6,6 @@ if(result EQUAL 0)
add_dependencies
(
test_grouped_gemm test_grouped_gemm_splitk
)
endif
()
add_gtest_executable
(
test_grouped_gemm_two_stage_splitk test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_grouped_gemm_two_stage_splitk PRIVATE utility device_grouped_gemm_instance
)
add_dependencies
(
test_grouped_gemm test_grouped_gemm_two_stage_splitk
)
endif
()
add_gtest_executable
(
test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance
)
...
...
test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp
View file @
4525c5d7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include <vector>
...
...
@@ -10,25 +10,35 @@
#include "gtest/gtest.h"
#include "test_grouped_gemm_util.hpp"
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F8
=
ck
::
f8_t
;
using
I8
=
int8_t
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
RRR_F16_F16_F16
=
ck
::
test
::
TestGroupedGemm
<
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
>>
;
using
RCR_F16_F16_F16
=
ck
::
test
::
TestGroupedGemm
<
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
>>
;
using
RRR_F16_F16_F16_LargeK
=
ck
::
test
::
TestGroupedGemm
<
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
>>
;
using
RCR_F16_F16_F16_LargeK
=
ck
::
test
::
TestGroupedGemm
<
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
>>
;
const
std
::
vector
<
int
>
KBATCH
{
1
,
2
,
3
,
5
,
8
};
INSTANTIATE_TEST_SUITE_P
(
TestGroupedGemm_splitk_MK_KN
,
RRR_F16_F16_F16
,
testing
::
ValuesIn
(
KBATCH
));
INSTANTIATE_TEST_SUITE_P
(
TestGroupedGemm_splitk_MK_NK
,
RCR_F16_F16_F16
,
testing
::
ValuesIn
(
KBATCH
));
INSTANTIATE_TEST_SUITE_P
(
TestGroupedGemm_splitk_LargeK_MK_KN
,
RRR_F16_F16_F16_LargeK
,
testing
::
Values
(
32
,
64
));
INSTANTIATE_TEST_SUITE_P
(
TestGroupedGemm_splitk_LargeK_MK_NK
,
RCR_F16_F16_F16_LargeK
,
testing
::
Values
(
32
,
64
));
template
<
typename
Tuple
>
class
TestGroupedGemm
:
public
ck
::
test
::
TestGroupedGemm
<
Tuple
>
{
};
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
BF16
,
I8
,
BF16
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
BF16
,
I8
,
BF16
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F8
,
F16
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F8
,
F16
,
F16
>
>
;
// clang-format on
TYPED_TEST_SUITE
(
TestGroupedGemm
,
KernelTypes
);
#include "test_grouped_gemm_ut_cases.inc"
test/grouped_gemm/test_grouped_gemm_ut_cases.inc
View file @
4525c5d7
#pragma once
T
EST_P
(
RRR_F16_F16_F16
,
TinyCases
)
T
YPED_TEST
(
TestGroupedGemm
,
TinyCases
)
{
const
std
::
vector
<
int
>
Ms
{
0
,
1
};
constexpr
int
N
=
768
;
...
...
@@ -8,14 +8,11 @@ TEST_P(RRR_F16_F16_F16, TinyCases)
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
()
);
this
->
Run
(
Ms
,
Ns
,
Ks
);
}
T
EST_P
(
RRR_F16_F16_F16
,
SmallCases
)
T
YPED_TEST
(
TestGroupedGemm
,
SmallCases
)
{
const
std
::
vector
<
int
>
Ms
{
2
,
1
,
3
,
4
,
5
,
0
};
constexpr
int
N
=
768
;
...
...
@@ -23,14 +20,11 @@ TEST_P(RRR_F16_F16_F16, SmallCases)
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
()
);
this
->
Run
(
Ms
,
Ns
,
Ks
);
}
T
EST_P
(
RRR_F16_F16_F16
,
MidCases
)
T
YPED_TEST
(
TestGroupedGemm
,
MidCases
)
{
const
std
::
vector
<
int
>
Ms
{
167
,
183
,
177
,
153
,
139
,
204
};
constexpr
int
N
=
768
;
...
...
@@ -38,14 +32,11 @@ TEST_P(RRR_F16_F16_F16, MidCases)
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
()
);
this
->
Run
(
Ms
,
Ns
,
Ks
);
}
T
EST_P
(
RRR_F16_F16_F16
,
Regular
)
T
YPED_TEST
(
TestGroupedGemm
,
Regular
)
{
const
std
::
vector
<
int
>
Ms
{
64
,
128
,
256
};
constexpr
int
N
=
768
;
...
...
@@ -53,14 +44,11 @@ TEST_P(RRR_F16_F16_F16, Regular)
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
()
);
this
->
Run
(
Ms
,
Ns
,
Ks
);
}
T
EST_P
(
RRR_F16_F16_F16
,
MNKPadded
)
T
YPED_TEST
(
TestGroupedGemm
,
MNKPadded
)
{
const
std
::
vector
<
int
>
Ms
{
127
,
150
,
188
,
210
};
constexpr
int
N
=
136
;
...
...
@@ -68,88 +56,11 @@ TEST_P(RRR_F16_F16_F16, MNKPadded)
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
()
);
this
->
Run
(
Ms
,
Ns
,
Ks
);
}
TEST_P
(
RCR_F16_F16_F16
,
TinyCases
)
{
const
std
::
vector
<
int
>
Ms
{
0
,
1
};
constexpr
int
N
=
768
;
constexpr
int
K
=
544
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RCR_F16_F16_F16
,
SmallCases
)
{
const
std
::
vector
<
int
>
Ms
{
2
,
1
,
3
,
4
,
5
,
0
};
constexpr
int
N
=
768
;
constexpr
int
K
=
544
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RCR_F16_F16_F16
,
MidCases
)
{
const
std
::
vector
<
int
>
Ms
{
167
,
183
,
177
,
153
,
139
,
204
};
constexpr
int
N
=
768
;
constexpr
int
K
=
544
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RCR_F16_F16_F16
,
Regular
)
{
const
std
::
vector
<
int
>
Ms
{
32
,
64
,
128
,
256
};
constexpr
int
N
=
768
;
constexpr
int
K
=
320
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RCR_F16_F16_F16
,
MNKPadded
)
{
const
std
::
vector
<
int
>
Ms
{
127
,
150
,
188
,
210
};
constexpr
int
N
=
136
;
constexpr
int
K
=
280
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RRR_F16_F16_F16_LargeK
,
TestLargeKBatch
)
TYPED_TEST
(
TestGroupedGemm
,
TestLargeKBatch
)
{
const
std
::
vector
<
int
>
Ms
{
188
,
210
};
constexpr
int
N
=
768
;
...
...
@@ -157,24 +68,8 @@ TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RCR_F16_F16_F16_LargeK
,
TestLargeKBatch
)
{
const
std
::
vector
<
int
>
Ms
{
188
,
210
};
constexpr
int
N
=
768
;
constexpr
int
K
=
4096
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
k_batches_
=
{
32
,
64
};
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
()
);
this
->
Run
(
Ms
,
Ns
,
Ks
);
}
test/grouped_gemm/test_grouped_gemm_util.hpp
View file @
4525c5d7
...
...
@@ -22,7 +22,6 @@
#include "ck/utility/tuple.hpp"
#include "ck/utility/number.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp"
#include "profiler/profile_grouped_gemm_two_stage_impl.hpp"
namespace
ck
{
namespace
test
{
...
...
@@ -40,7 +39,7 @@ std::string serialize_range(const Range& range)
}
template
<
typename
Tuple
>
class
TestGroupedGemm
:
public
testing
::
Test
WithParam
<
int
>
class
TestGroupedGemm
:
public
testing
::
Test
{
protected:
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
...
...
@@ -50,23 +49,77 @@ class TestGroupedGemm : public testing::TestWithParam<int>
using
BDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
EDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
public:
static
constexpr
bool
verify_
=
true
;
static
constexpr
int
init_method_
=
1
;
//
decimal
value initialization
static
constexpr
int
init_method_
=
1
;
//
integer
value initialization
static
constexpr
bool
log_
=
false
;
static
constexpr
bool
bench_
=
false
;
// measure kernel performance
static
constexpr
int
n_warmup_
=
0
;
static
constexpr
int
n_iter_
=
1
;
std
::
vector
<
int
>
k_batches_
;
void
SetUp
()
override
{}
void
SetUp
()
override
{
k_batches_
=
{
1
,
2
,
3
,
5
,
8
};
}
private:
template
<
typename
Layout
>
void
SetStrides
(
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
rows
,
const
std
::
vector
<
int
>&
cols
)
const
{
if
(
std
::
is_same_v
<
Layout
,
Row
>
)
{
for
(
const
auto
c
:
cols
)
{
strides
.
emplace_back
(
c
);
}
}
else
if
(
std
::
is_same_v
<
Layout
,
Col
>
)
{
for
(
const
auto
r
:
rows
)
{
strides
.
emplace_back
(
r
);
}
}
}
public:
void
Run
(
const
std
::
vector
<
int
>&
Ms
,
const
std
::
vector
<
int
>&
Ns
,
const
std
::
vector
<
int
>&
Ks
,
const
std
::
vector
<
int
>&
StrideAs
,
const
std
::
vector
<
int
>&
StrideBs
,
const
std
::
vector
<
int
>&
StrideCs
,
int
kbatch
=
1
,
int
n_warmup
=
1
,
int
n_iter
=
10
)
const
std
::
vector
<
int
>&
StrideAs
=
{},
const
std
::
vector
<
int
>&
StrideBs
=
{},
const
std
::
vector
<
int
>&
StrideCs
=
{})
{
std
::
vector
<
int
>
stride_as
=
StrideAs
;
std
::
vector
<
int
>
stride_bs
=
StrideBs
;
std
::
vector
<
int
>
stride_cs
=
StrideCs
;
if
(
stride_as
.
empty
())
{
SetStrides
<
ALayout
>
(
stride_as
,
Ms
,
Ks
);
}
if
(
stride_bs
.
empty
())
{
SetStrides
<
BLayout
>
(
stride_bs
,
Ks
,
Ns
);
}
if
(
stride_cs
.
empty
())
{
SetStrides
<
ELayout
>
(
stride_cs
,
Ms
,
Ns
);
}
RunSingle
(
Ms
,
Ns
,
Ks
,
stride_as
,
stride_bs
,
stride_cs
,
k_batches_
);
}
void
RunSingle
(
const
std
::
vector
<
int
>&
Ms
,
const
std
::
vector
<
int
>&
Ns
,
const
std
::
vector
<
int
>&
Ks
,
const
std
::
vector
<
int
>&
StrideAs
,
const
std
::
vector
<
int
>&
StrideBs
,
const
std
::
vector
<
int
>&
StrideCs
,
const
std
::
vector
<
int
>&
kbatches
)
{
bool
pass
=
ck
::
profiler
::
profile_grouped_gemm_impl
<
ADataType
,
BDataType
,
...
...
@@ -84,61 +137,9 @@ class TestGroupedGemm : public testing::TestWithParam<int>
StrideAs
,
StrideBs
,
StrideCs
,
kbatch
,
n_warmup
,
n_iter
);
EXPECT_TRUE
(
pass
);
}
};
template
<
typename
Tuple
>
class
TestGroupedGemmTwoStage
:
public
testing
::
TestWithParam
<
int
>
{
protected:
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
ELayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
ADataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
EDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
public:
static
constexpr
bool
verify_
=
true
;
static
constexpr
int
init_method_
=
1
;
// decimal value initialization
static
constexpr
bool
log_
=
false
;
static
constexpr
bool
bench_
=
false
;
// measure kernel performance
void
SetUp
()
override
{}
void
Run
(
const
std
::
vector
<
int
>&
Ms
,
const
std
::
vector
<
int
>&
Ns
,
const
std
::
vector
<
int
>&
Ks
,
const
std
::
vector
<
int
>&
StrideAs
,
const
std
::
vector
<
int
>&
StrideBs
,
const
std
::
vector
<
int
>&
StrideCs
,
int
kbatch
=
1
,
int
n_warmup
=
1
,
int
n_iter
=
10
)
{
bool
pass
=
ck
::
profiler
::
profile_grouped_gemm_two_stage_impl
<
ADataType
,
BDataType
,
EDataType
,
float
,
ALayout
,
BLayout
,
ELayout
>
(
verify_
,
init_method_
,
log_
,
bench_
,
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
kbatch
,
n_warmup
,
n_iter
);
kbatches
,
n_warmup_
,
n_iter_
);
EXPECT_TRUE
(
pass
);
}
};
...
...
@@ -263,7 +264,7 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
p_As
,
p_Bs
,
p_Ds
,
p_Cs
,
gemm_descs
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
if
(
kbatch
>
1
)
{
ggemm_instance
.
SetKBatchSize
(
argument
,
kbatch
);
ggemm_instance
.
SetKBatchSize
(
&
argument
,
kbatch
);
}
return
ggemm_instance
.
IsSupportedArgument
(
argument
);
...
...
@@ -300,13 +301,13 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
p_As
,
p_Bs
,
p_Ds
,
p_Cs
,
gemm_descs
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
if
(
kbatch
>
1
)
{
ggemm_instance
.
SetKBatchSize
(
argument
,
kbatch
);
ggemm_instance
.
SetKBatchSize
(
&
argument
,
kbatch
);
}
EXPECT_TRUE
(
ggemm_instance
.
IsSupportedArgument
(
argument
));
auto
invoker
=
ggemm_instance
.
MakeInvoker
();
DeviceMem
gemm_
desc_workspace
(
ggemm_instance
.
Get
WorkSpace
Size
(
&
argument
));
ggemm_instance
.
Set
WorkSpacePointer
(
&
argument
,
gemm_
desc_workspace
.
GetDeviceBuffer
());
DeviceMem
dev_
gemm_
kargs
(
ggemm_instance
.
Get
DeviceKernelArg
Size
(
&
argument
));
ggemm_instance
.
Set
DeviceKernelArgs
(
&
argument
,
dev_
gemm_
kargs
.
GetDeviceBuffer
());
return
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
}
};
...
...
Prev
1
…
12
13
14
15
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment