Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
129e58ae
Commit
129e58ae
authored
Jun 05, 2024
by
Adam Osewski
Browse files
Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2
parents
9bebfd42
cb0645be
Changes
188
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
290 additions
and
80 deletions
+290
-80
python/ck4inductor/util.py
python/ck4inductor/util.py
+7
-0
test/CMakeLists.txt
test/CMakeLists.txt
+32
-4
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
...uped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
+6
-13
test/grouped_gemm/CMakeLists.txt
test/grouped_gemm/CMakeLists.txt
+6
-0
test/grouped_gemm/test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp
...emm/test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp
+62
-0
test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc
test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc
+61
-0
test/grouped_gemm/test_grouped_gemm_util.hpp
test/grouped_gemm/test_grouped_gemm_util.hpp
+54
-1
test/position_embedding/position_embedding.cpp
test/position_embedding/position_embedding.cpp
+62
-62
No files found.
python/ck4inductor/util.py
0 → 100644
View file @
129e58ae
import
functools
import
os
@functools.lru_cache(maxsize=None)
def library_path():
    """Return the path of the ``library`` directory located next to this module.

    The path is built from the directory containing ``__file__``; the result is
    memoized, so it is computed only on the first call.
    """
    package_dir = os.path.dirname(__file__)
    return os.path.join(package_dir, 'library')
test/CMakeLists.txt
View file @
129e58ae
...
@@ -40,6 +40,13 @@ function(add_test_executable TEST_NAME)
...
@@ -40,6 +40,13 @@ function(add_test_executable TEST_NAME)
endif
()
endif
()
endforeach
()
endforeach
()
endif
()
endif
()
if
(
INSTANCES_ONLY
)
set
(
TEST_TARGETS
${
DEFAULT_GPU_TARGETS
}
)
else
()
set
(
TEST_TARGETS
${
GPU_TARGETS
}
)
endif
()
foreach
(
source IN LISTS ARGN
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
message
(
"removing dl test
${
source
}
"
)
message
(
"removing dl test
${
source
}
"
)
...
@@ -47,20 +54,27 @@ function(add_test_executable TEST_NAME)
...
@@ -47,20 +54,27 @@ function(add_test_executable TEST_NAME)
endif
()
endif
()
endforeach
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT
GPU
_TARGETS MATCHES
"gfx9"
AND source MATCHES
"xdl"
)
if
(
NOT
TEST
_TARGETS MATCHES
"gfx9"
AND source MATCHES
"xdl"
)
message
(
"removing xdl test
${
source
}
"
)
message
(
"removing xdl test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endif
()
endforeach
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT
GPU
_TARGETS MATCHES
"gfx11"
AND source MATCHES
"wmma"
)
if
(
NOT
TEST
_TARGETS MATCHES
"gfx11"
AND source MATCHES
"wmma"
)
message
(
"removing wmma test
${
source
}
"
)
message
(
"removing wmma test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endif
()
endforeach
()
endforeach
()
#only continue if there are some source files left on the list
#only continue if there are some source files left on the list
if
(
ARGN
)
if
(
ARGN
)
if
(
ARGN MATCHES
"_xdl"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103
)
elseif
(
ARGN MATCHES
"_wmma"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
endif
()
set_source_files_properties
(
${
ARGN
}
PROPERTIES LANGUAGE HIP
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
set_property
(
TARGET
${
TEST_NAME
}
PROPERTY HIP_ARCHITECTURES
${
TEST_TARGETS
}
)
target_link_libraries
(
${
TEST_NAME
}
PRIVATE getopt::getopt
)
target_link_libraries
(
${
TEST_NAME
}
PRIVATE getopt::getopt
)
add_test
(
NAME
${
TEST_NAME
}
COMMAND $<TARGET_FILE:
${
TEST_NAME
}
>
)
add_test
(
NAME
${
TEST_NAME
}
COMMAND $<TARGET_FILE:
${
TEST_NAME
}
>
)
add_dependencies
(
tests
${
TEST_NAME
}
)
add_dependencies
(
tests
${
TEST_NAME
}
)
...
@@ -105,6 +119,13 @@ function(add_gtest_executable TEST_NAME)
...
@@ -105,6 +119,13 @@ function(add_gtest_executable TEST_NAME)
endif
()
endif
()
endforeach
()
endforeach
()
endif
()
endif
()
if
(
INSTANCES_ONLY
)
set
(
TEST_TARGETS
${
DEFAULT_GPU_TARGETS
}
)
else
()
set
(
TEST_TARGETS
${
GPU_TARGETS
}
)
endif
()
foreach
(
source IN LISTS ARGN
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
message
(
"removing dl test
${
source
}
"
)
message
(
"removing dl test
${
source
}
"
)
...
@@ -112,20 +133,27 @@ function(add_gtest_executable TEST_NAME)
...
@@ -112,20 +133,27 @@ function(add_gtest_executable TEST_NAME)
endif
()
endif
()
endforeach
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT
GPU
_TARGETS MATCHES
"gfx9"
AND source MATCHES
"xdl"
)
if
(
NOT
TEST
_TARGETS MATCHES
"gfx9"
AND source MATCHES
"xdl"
)
message
(
"removing xdl test
${
source
}
"
)
message
(
"removing xdl test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endif
()
endforeach
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT
GPU
_TARGETS MATCHES
"gfx11"
AND source MATCHES
"wmma"
)
if
(
NOT
TEST
_TARGETS MATCHES
"gfx11"
AND source MATCHES
"wmma"
)
message
(
"removing wmma test
${
source
}
"
)
message
(
"removing wmma test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endif
()
endforeach
()
endforeach
()
#only continue if there are some source files left on the list
#only continue if there are some source files left on the list
if
(
ARGN
)
if
(
ARGN
)
if
(
ARGN MATCHES
"_xdl"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103
)
elseif
(
ARGN MATCHES
"_wmma"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
endif
()
set_source_files_properties
(
${
ARGN
}
PROPERTIES LANGUAGE HIP
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
set_property
(
TARGET
${
TEST_NAME
}
PROPERTY HIP_ARCHITECTURES
${
TEST_TARGETS
}
)
add_dependencies
(
tests
${
TEST_NAME
}
)
add_dependencies
(
tests
${
TEST_NAME
}
)
add_dependencies
(
check
${
TEST_NAME
}
)
add_dependencies
(
check
${
TEST_NAME
}
)
...
...
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
View file @
129e58ae
...
@@ -32,19 +32,8 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
...
@@ -32,19 +32,8 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
std
::
vector
<
ck
::
utils
::
conv
::
ConvParam
>
conv_params
;
std
::
vector
<
ck
::
utils
::
conv
::
ConvParam
>
conv_params
;
std
::
vector
<
ck
::
index_t
>
split_ks
{
1
,
2
};
std
::
vector
<
ck
::
index_t
>
split_ks
{
1
,
2
};
bool
skip_case
(
const
ck
::
utils
::
conv
::
ConvParam
&
params
,
const
ck
::
index_t
split_k
)
bool
skip_case
(
const
ck
::
index_t
split_k
)
{
{
// Odd K or C values are supported only by DL and WMMA
// kernels (only applies to fp16)
// DL and WMMA kernels currently support only `split_k=1`
if
constexpr
(
std
::
is_same_v
<
InDataType
,
ck
::
half_t
>
)
{
if
(
split_k
!=
1
&&
(
params
.
K_
%
2
!=
0
||
params
.
C_
%
2
!=
0
))
{
return
true
;
}
}
// 1d NWGC is only supported by DL kernel
// 1d NWGC is only supported by DL kernel
// DL kernel is only supported for split_k=1
// DL kernel is only supported for split_k=1
if
constexpr
(
std
::
is_same_v
<
InLayout
,
NWGC
>
&&
std
::
is_same_v
<
OutLayout
,
NWGK
>
)
if
constexpr
(
std
::
is_same_v
<
InLayout
,
NWGC
>
&&
std
::
is_same_v
<
OutLayout
,
NWGK
>
)
...
@@ -100,7 +89,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
...
@@ -100,7 +89,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
{
{
for
(
auto
&
param
:
conv_params
)
for
(
auto
&
param
:
conv_params
)
{
{
if
(
!
skip_case
(
param
,
split_k
))
if
(
!
skip_case
(
split_k
))
{
{
pass
=
pass
&&
ck
::
profiler
::
profile_grouped_conv_bwd_weight_impl
<
NDimSpatial
{},
pass
=
pass
&&
ck
::
profiler
::
profile_grouped_conv_bwd_weight_impl
<
NDimSpatial
{},
InLayout
,
InLayout
,
...
@@ -189,6 +178,8 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
...
@@ -189,6 +178,8 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
32
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
32
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
64
,
3
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
64
,
3
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
1
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
1
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
(
{
2
,
16
,
16
,
1
,
1
,
{
3
,
3
},
{
28
,
28
},
{
2
,
2
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
Run
();
this
->
Run
();
}
}
...
@@ -207,5 +198,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D)
...
@@ -207,5 +198,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D)
{
3
,
1
,
1
,
64
,
3
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
{
3
,
1
,
1
,
64
,
3
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
this
->
conv_params
.
push_back
(
{
3
,
1
,
1
,
1
,
1
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
{
3
,
1
,
1
,
1
,
1
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
16
,
16
,
1
,
1
,
{
3
,
3
,
3
},
{
28
,
28
,
28
},
{
2
,
2
,
2
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
Run
();
this
->
Run
();
}
}
test/grouped_gemm/CMakeLists.txt
View file @
129e58ae
...
@@ -6,6 +6,12 @@ if(result EQUAL 0)
...
@@ -6,6 +6,12 @@ if(result EQUAL 0)
add_dependencies
(
test_grouped_gemm test_grouped_gemm_splitk
)
add_dependencies
(
test_grouped_gemm test_grouped_gemm_splitk
)
endif
()
endif
()
add_gtest_executable
(
test_grouped_gemm_two_stage_splitk test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_grouped_gemm_two_stage_splitk PRIVATE utility device_grouped_gemm_instance
)
add_dependencies
(
test_grouped_gemm test_grouped_gemm_two_stage_splitk
)
endif
()
add_gtest_executable
(
test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp
)
add_gtest_executable
(
test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp
)
if
(
result EQUAL 0
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance
)
target_link_libraries
(
test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance
)
...
...
test/grouped_gemm/test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp
0 → 100644
View file @
129e58ae
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include <vector>
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/data_type.hpp"
#include "gtest/gtest.h"
#include "test_grouped_gemm_util.hpp"
// Element data types used by the fixtures below.
using F16  = ck::half_t;
using BF16 = ck::bhalf_t;
using I8   = int8_t;

// Matrix layouts.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

// Fixture aliases. Name pattern: <A,B,E layouts>_<A type>_<B type>_<E type>,
// where R = Row and C = Col. Tuple order is
// (ALayout, BLayout, ELayout, ADataType, BDataType, EDataType).

// fp16 x fp16 -> fp16
using RRR_F16_F16_F16 =
    ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, F16, F16, F16>>;
using RCR_F16_F16_F16 =
    ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, F16, F16, F16>>;

// Same fp16 configurations under separate suite names; they are instantiated
// with larger k-batch values further down in this file.
using RRR_F16_F16_F16_LargeK =
    ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, F16, F16, F16>>;
using RCR_F16_F16_F16_LargeK =
    ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, F16, F16, F16>>;

// bf16 x bf16 -> bf16
using RRR_BF16_BF16_BF16 =
    ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, BF16, BF16, BF16>>;
using RCR_BF16_BF16_BF16 =
    ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, BF16, BF16, BF16>>;

// bf16 x int8 -> bf16
using RRR_BF16_I8_BF16 =
    ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, BF16, I8, BF16>>;
using RCR_BF16_I8_BF16 =
    ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, BF16, I8, BF16>>;
// k-batch values used as the test parameter for the regular suites.
const std::vector<int> KBATCH = {1, 2, 3, 5, 8};

// fp16 suites
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN,
                         RRR_F16_F16_F16,
                         testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK,
                         RCR_F16_F16_F16,
                         testing::ValuesIn(KBATCH));

// bf16 suites
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN_BF16,
                         RRR_BF16_BF16_BF16,
                         testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK_BF16,
                         RCR_BF16_BF16_BF16,
                         testing::ValuesIn(KBATCH));

// bf16 x int8 suites
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN_BF16_INT8,
                         RRR_BF16_I8_BF16,
                         testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK_BF16_INT8,
                         RCR_BF16_I8_BF16,
                         testing::ValuesIn(KBATCH));

// LargeK suites use their own, larger parameter values instead of KBATCH.
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_LargeK_MK_KN,
                         RRR_F16_F16_F16_LargeK,
                         testing::Values(32, 64));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_LargeK_MK_NK,
                         RCR_F16_F16_F16_LargeK,
                         testing::Values(32, 64));
#include "test_grouped_gemm_ut_cases.inc"
#include "test_grouped_gemm_two_stage_ut_cases.inc"
test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc
0 → 100644
View file @
129e58ae
#pragma once
// Grouped GEMM with four groups: M differs per group, N and K are shared.
// The suite parameter is forwarded to Run() as its last argument.
TEST_P(RRR_BF16_BF16_BF16, MNKPadded)
{
    constexpr int N = 136;
    constexpr int K = 280;

    const std::vector<int> Ms{127, 150, 188, 210};
    const auto group_count = Ms.size();

    // One entry per group, all holding the same value.
    const std::vector<int> Ns(group_count, N);
    const std::vector<int> Ks(group_count, K);

    // A uses K as its stride; B and C use N.
    const std::vector<int> StrideAs(group_count, K);
    const std::vector<int> StrideBs(group_count, N);
    const std::vector<int> StrideCs(group_count, N);

    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
// Grouped GEMM with four groups: M differs per group, N and K are shared.
// The suite parameter is forwarded to Run() as its last argument.
TEST_P(RCR_BF16_BF16_BF16, MNKPadded)
{
    constexpr int N = 136;
    constexpr int K = 280;

    const std::vector<int> Ms{127, 150, 188, 210};
    const auto group_count = Ms.size();

    // One entry per group, all holding the same value.
    const std::vector<int> Ns(group_count, N);
    const std::vector<int> Ks(group_count, K);

    // A and B both use K as their stride; C uses N.
    const std::vector<int> StrideAs(group_count, K);
    const std::vector<int> StrideBs(group_count, K);
    const std::vector<int> StrideCs(group_count, N);

    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
// Grouped GEMM with four groups: M differs per group, N and K are shared.
// The suite parameter is forwarded to Run() as its last argument.
TEST_P(RRR_BF16_I8_BF16, MNKPadded)
{
    constexpr int N = 136;
    constexpr int K = 280;

    const std::vector<int> Ms{127, 150, 188, 210};
    const auto group_count = Ms.size();

    // One entry per group, all holding the same value.
    const std::vector<int> Ns(group_count, N);
    const std::vector<int> Ks(group_count, K);

    // A uses K as its stride; B and C use N.
    const std::vector<int> StrideAs(group_count, K);
    const std::vector<int> StrideBs(group_count, N);
    const std::vector<int> StrideCs(group_count, N);

    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
// Grouped GEMM with four groups: M differs per group, N and K are shared.
// The suite parameter is forwarded to Run() as its last argument.
TEST_P(RCR_BF16_I8_BF16, MNKPadded)
{
    constexpr int N = 136;
    constexpr int K = 280;

    const std::vector<int> Ms{127, 150, 188, 210};
    const auto group_count = Ms.size();

    // One entry per group, all holding the same value.
    const std::vector<int> Ns(group_count, N);
    const std::vector<int> Ks(group_count, K);

    // A and B both use K as their stride; C uses N.
    const std::vector<int> StrideAs(group_count, K);
    const std::vector<int> StrideBs(group_count, K);
    const std::vector<int> StrideCs(group_count, N);

    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
test/grouped_gemm/test_grouped_gemm_util.hpp
View file @
129e58ae
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/number.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp"
#include "profiler/profile_grouped_gemm_two_stage_impl.hpp"
namespace
ck
{
namespace
ck
{
namespace
test
{
namespace
test
{
...
@@ -90,6 +91,58 @@ class TestGroupedGemm : public testing::TestWithParam<int>
...
@@ -90,6 +91,58 @@ class TestGroupedGemm : public testing::TestWithParam<int>
}
}
};
};
/// Value-parameterized gtest fixture that drives
/// ck::profiler::profile_grouped_gemm_two_stage_impl.
///
/// @tparam Tuple std::tuple of, in order:
///         ALayout, BLayout, ELayout, ADataType, BDataType, EDataType.
///
/// The test parameter type is `int`.
template <typename Tuple>
class TestGroupedGemmTwoStage : public testing::TestWithParam<int>
{
    protected:
    // Unpack the configuration tuple.
    using ALayout   = std::tuple_element_t<0, Tuple>;
    using BLayout   = std::tuple_element_t<1, Tuple>;
    using ELayout   = std::tuple_element_t<2, Tuple>;
    using ADataType = std::tuple_element_t<3, Tuple>;
    using BDataType = std::tuple_element_t<4, Tuple>;
    using EDataType = std::tuple_element_t<5, Tuple>;

    public:
    // Fixed settings passed as the leading arguments of the profiler call.
    static constexpr bool verify_     = true;
    static constexpr int init_method_ = 1;     // decimal value initialization
    static constexpr bool log_        = false;
    static constexpr bool bench_      = false; // measure kernel performance

    void SetUp() override {}

    /// Run the profiler for one grouped problem and expect it to report success.
    ///
    /// @param Ms, Ns, Ks  Problem sizes, forwarded unchanged to the profiler.
    /// @param StrideAs, StrideBs, StrideCs  Strides, forwarded unchanged.
    /// @param kbatch    Forwarded to the profiler (default 1).
    /// @param n_warmup  Forwarded to the profiler (default 1).
    /// @param n_iter    Forwarded to the profiler (default 10).
    void Run(const std::vector<int>& Ms,
             const std::vector<int>& Ns,
             const std::vector<int>& Ks,
             const std::vector<int>& StrideAs,
             const std::vector<int>& StrideBs,
             const std::vector<int>& StrideCs,
             int kbatch   = 1,
             int n_warmup = 1,
             int n_iter   = 10)
    {
        // NOTE(review): the 4th template argument (float) is presumably the
        // accumulator type — confirm against the profiler header.
        const bool passed =
            ck::profiler::profile_grouped_gemm_two_stage_impl<ADataType,
                                                              BDataType,
                                                              EDataType,
                                                              float,
                                                              ALayout,
                                                              BLayout,
                                                              ELayout>(verify_,
                                                                       init_method_,
                                                                       log_,
                                                                       bench_,
                                                                       Ms,
                                                                       Ns,
                                                                       Ks,
                                                                       StrideAs,
                                                                       StrideBs,
                                                                       StrideCs,
                                                                       kbatch,
                                                                       n_warmup,
                                                                       n_iter);
        EXPECT_TRUE(passed);
    }
};
template
<
typename
ALayout
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
ELayout
,
typename
ELayout
,
...
...
test/position_embedding/position_embedding.cpp
View file @
129e58ae
...
@@ -131,74 +131,74 @@ int main()
...
@@ -131,74 +131,74 @@ int main()
0
,
1
,
2
,
3
,
4
,
5
,
0
,
1
,
2
,
3
,
4
,
5
,
0
,
1
,
2
,
3
,
4
,
5
});
0
,
1
,
2
,
3
,
4
,
5
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
1
,
2
,
3
,
4
,
5
,
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
-
1
,
-
2
,
-
3
,
-
4
,
-
5
,
1
,
0
,
1
,
2
,
3
,
4
,
-
1
,
0
,
-
1
,
-
2
,
-
3
,
-
4
,
2
,
1
,
0
,
1
,
2
,
3
,
-
2
,
-
1
,
0
,
-
1
,
-
2
,
-
3
,
3
,
2
,
1
,
0
,
1
,
2
});
-
3
,
-
2
,
-
1
,
0
,
-
1
,
-
2
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
1
,
2
,
3
,
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
-
1
,
-
2
,
-
3
,
1
,
0
,
1
,
2
,
-
1
,
0
,
-
1
,
-
2
,
2
,
1
,
0
,
1
,
-
2
,
-
1
,
0
,
-
1
,
3
,
2
,
1
,
0
,
-
3
,
-
2
,
-
1
,
0
,
4
,
3
,
2
,
1
,
-
4
,
-
3
,
-
2
,
-
1
,
5
,
4
,
3
,
2
});
-
5
,
-
4
,
-
3
,
-
2
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
1
,
2
,
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
-
1
,
-
2
,
1
,
0
,
1
,
-
1
,
0
,
-
1
,
2
,
1
,
0
});
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
2
,
1
,
0
,
1
,
2
,
3
,
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
-
2
,
-
1
,
0
,
-
1
,
-
2
,
-
3
,
3
,
2
,
1
,
0
,
1
,
2
,
-
3
,
-
2
,
-
1
,
0
,
-
1
,
-
2
,
4
,
3
,
2
,
1
,
0
,
1
,
-
4
,
-
3
,
-
2
,
-
1
,
0
,
-
1
,
5
,
4
,
3
,
2
,
1
,
0
});
-
5
,
-
4
,
-
3
,
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
2
,
3
,
4
,
5
,
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
-
2
,
-
3
,
-
4
,
-
5
,
1
,
2
,
3
,
4
,
-
1
,
-
2
,
-
3
,
-
4
,
0
,
1
,
2
,
3
,
0
,
-
1
,
-
2
,
-
3
,
1
,
0
,
1
,
2
,
-
1
,
0
,
-
1
,
-
2
,
2
,
1
,
0
,
1
,
-
2
,
-
1
,
0
,
-
1
,
3
,
2
,
1
,
0
});
-
3
,
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
0
,
1
,
2
,
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
0
,
-
1
,
-
2
,
1
,
0
,
1
,
-
1
,
0
,
-
1
,
2
,
1
,
0
});
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
VERTICAL
,
{
0
,
1
,
2
,
3
,
4
,
5
,
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
VERTICAL
,
{
0
,
1
,
2
,
3
,
4
,
5
,
0
,
1
,
2
,
3
,
4
,
5
,
0
,
1
,
2
,
3
,
4
,
5
,
0
,
1
,
2
,
3
,
4
,
5
,
0
,
1
,
2
,
3
,
4
,
5
,
0
,
1
,
2
,
3
,
4
,
5
});
0
,
1
,
2
,
3
,
4
,
5
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
1
,
2
,
3
,
4
,
5
,
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
-
1
,
-
2
,
-
3
,
-
4
,
-
5
,
1
,
0
,
1
,
2
,
3
,
4
,
-
1
,
0
,
-
1
,
-
2
,
-
3
,
-
4
,
2
,
1
,
0
,
1
,
2
,
3
,
-
2
,
-
1
,
0
,
-
1
,
-
2
,
-
3
,
3
,
2
,
1
,
0
,
1
,
2
});
-
3
,
-
2
,
-
1
,
0
,
-
1
,
-
2
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
1
,
2
,
3
,
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
-
1
,
-
2
,
-
3
,
1
,
0
,
1
,
2
,
-
1
,
0
,
-
1
,
-
2
,
2
,
1
,
0
,
1
,
-
2
,
-
1
,
0
,
-
1
,
3
,
2
,
1
,
0
,
-
3
,
-
2
,
-
1
,
0
,
4
,
3
,
2
,
1
,
-
4
,
-
3
,
-
2
,
-
1
,
5
,
4
,
3
,
2
});
-
5
,
-
4
,
-
3
,
-
2
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
1
,
2
,
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
-
1
,
-
2
,
1
,
0
,
1
,
-
1
,
0
,
-
1
,
2
,
1
,
0
});
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
2
,
1
,
0
,
1
,
2
,
3
,
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
-
2
,
-
1
,
0
,
-
1
,
-
2
,
-
3
,
3
,
2
,
1
,
0
,
1
,
2
,
-
3
,
-
2
,
-
1
,
0
,
-
1
,
-
2
,
4
,
3
,
2
,
1
,
0
,
1
,
-
4
,
-
3
,
-
2
,
-
1
,
0
,
-
1
,
5
,
4
,
3
,
2
,
1
,
0
});
-
5
,
-
4
,
-
3
,
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
2
,
3
,
4
,
5
,
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
-
2
,
-
3
,
-
4
,
-
5
,
1
,
2
,
3
,
4
,
-
1
,
-
2
,
-
3
,
-
4
,
0
,
1
,
2
,
3
,
0
,
-
1
,
-
2
,
-
3
,
1
,
0
,
1
,
2
,
-
1
,
0
,
-
1
,
-
2
,
2
,
1
,
0
,
1
,
-
2
,
-
1
,
0
,
-
1
,
3
,
2
,
1
,
0
});
-
3
,
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
0
,
1
,
2
,
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
0
,
-
1
,
-
2
,
1
,
0
,
1
,
-
1
,
0
,
-
1
,
2
,
1
,
0
});
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_slope_generation
<
float
>
(
8
,
{
0.5
,
0.25
,
0.125
,
0.0625
,
0.03125
,
0.015625
,
0.0078125
,
0.00390625
});
rtn
&=
test_alibi_slope_generation
<
float
>
(
8
,
{
0.5
,
0.25
,
0.125
,
0.0625
,
0.03125
,
0.015625
,
0.0078125
,
0.00390625
});
rtn
&=
test_alibi_slope_generation
<
float
>
(
16
,
{
0.7071067811865476
,
0.5
,
0.35355339059327384
,
0.25000000000000006
,
0.17677669529663692
,
rtn
&=
test_alibi_slope_generation
<
float
>
(
16
,
{
0.7071067811865476
,
0.5
,
0.35355339059327384
,
0.25000000000000006
,
0.17677669529663692
,
...
...
Prev
1
…
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment