Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3782ed3b
Commit
3782ed3b
authored
Jun 12, 2023
by
Bartlomiej Kocot
Browse files
reproducer
parent
0f48e38a
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
244 additions
and
258 deletions
+244
-258
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp
...ry/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp
+203
-203
library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/CMakeLists.txt
...peration_instance/gpu/batched_gemm_multi_d/CMakeLists.txt
+1
-16
library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
...ched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
+28
-28
profiler/include/profiler/profile_batched_gemm_impl.hpp
profiler/include/profiler/profile_batched_gemm_impl.hpp
+2
-2
test/batched_gemm_multi_d/test_batched_gemm_multi_d.cpp
test/batched_gemm_multi_d/test_batched_gemm_multi_d.cpp
+10
-9
No files found.
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp
View file @
3782ed3b
This diff is collapsed.
Click to expand it.
library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/CMakeLists.txt
View file @
3782ed3b
add_instance_library
(
device_batched_gemm_multi_d_instance
add_instance_library
(
device_batched_gemm_multi_d_instance
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp
)
)
library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
View file @
3782ed3b
This diff is collapsed.
Click to expand it.
profiler/include/profiler/profile_batched_gemm_impl.hpp
View file @
3782ed3b
...
@@ -86,8 +86,8 @@ bool profile_batched_gemm_impl(int do_verification,
...
@@ -86,8 +86,8 @@ bool profile_batched_gemm_impl(int do_verification,
{
{
case
0
:
break
;
case
0
:
break
;
case
1
:
case
1
:
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
1
,
2
});
b_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
b_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
1
,
2
});
break
;
break
;
default:
default:
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
...
...
test/batched_gemm_multi_d/test_batched_gemm_multi_d.cpp
View file @
3782ed3b
...
@@ -25,10 +25,10 @@ class TestBatchedGemmMultiD : public ::testing::Test
...
@@ -25,10 +25,10 @@ class TestBatchedGemmMultiD : public ::testing::Test
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
static
constexpr
int
M
=
512
;
static
constexpr
int
M
=
64
;
static
constexpr
int
N
=
256
;
static
constexpr
int
N
=
8
;
static
constexpr
int
K
=
128
;
static
constexpr
int
K
=
64
;
static
constexpr
int
BatchCount
=
3
;
static
constexpr
int
BatchCount
=
1
;
template
<
typename
DataType
>
template
<
typename
DataType
>
void
Run
()
void
Run
()
...
@@ -61,14 +61,15 @@ class TestBatchedGemmMultiD : public ::testing::Test
...
@@ -61,14 +61,15 @@ class TestBatchedGemmMultiD : public ::testing::Test
}
}
};
};
using
KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
Row
,
Row
,
Row
>
,
using
KernelTypes
=
::
testing
::
Types
<
//std::tuple<Row, Row, Row>,
std
::
tuple
<
Row
,
Col
,
Row
>
,
std
::
tuple
<
Row
,
Col
,
Row
>
std
::
tuple
<
Col
,
Row
,
Row
>
,
// std::tuple<Col, Row, Row>,
std
::
tuple
<
Col
,
Col
,
Row
>>
;
// std::tuple<Col, Col, Row>
>
;
}
// namespace
}
// namespace
TYPED_TEST_SUITE
(
TestBatchedGemmMultiD
,
KernelTypes
);
TYPED_TEST_SUITE
(
TestBatchedGemmMultiD
,
KernelTypes
);
TYPED_TEST
(
TestBatchedGemmMultiD
,
f16
)
{
this
->
template
Run
<
F16
>();
}
TYPED_TEST
(
TestBatchedGemmMultiD
,
f16
)
{
this
->
template
Run
<
F16
>();
}
TYPED_TEST
(
TestBatchedGemmMultiD
,
int8
)
{
this
->
template
Run
<
int8_t
>();
}
//
TYPED_TEST(TestBatchedGemmMultiD, int8) { this->template Run<int8_t>(); }
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment