Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c79316e2
Commit
c79316e2
authored
Jun 05, 2023
by
Bartlomiej Kocot
Committed by
Bartłomiej Kocot
Jun 09, 2023
Browse files
Fix comments
parent
e6334634
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
113 additions
and
268 deletions
+113
-268
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
...ion/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
+3
-5
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp
...ry/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp
+16
-8
profiler/include/profiler/profile_batched_gemm_impl.hpp
profiler/include/profiler/profile_batched_gemm_impl.hpp
+1
-0
test/batched_gemm_multi_d/CMakeLists.txt
test/batched_gemm_multi_d/CMakeLists.txt
+2
-7
test/batched_gemm_multi_d/batched_gemm_multi_d_fp16.cpp
test/batched_gemm_multi_d/batched_gemm_multi_d_fp16.cpp
+0
-124
test/batched_gemm_multi_d/batched_gemm_multi_d_int8.cpp
test/batched_gemm_multi_d/batched_gemm_multi_d_int8.cpp
+0
-124
test/batched_gemm_multi_d/test_batched_gemm_multi_d.cpp
test/batched_gemm_multi_d/test_batched_gemm_multi_d.cpp
+91
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
View file @
c79316e2
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c)
2018-
2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -201,8 +201,6 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
static
auto
MakeAGridDescriptor_K0_M_K1
(
index_t
M
,
index_t
K
,
index_t
StrideA
)
{
assert
(
K
%
K1
==
0
);
const
index_t
K0
=
K
/
K1
;
const
auto
a_grid_desc_m_k
=
[
&
]()
{
...
...
@@ -240,8 +238,6 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
static
auto
MakeBGridDescriptor_K0_N_K1
(
index_t
K
,
index_t
N
,
index_t
StrideB
)
{
assert
(
K
%
K1
==
0
);
const
index_t
K0
=
K
/
K1
;
const
auto
b_grid_desc_k_n
=
[
&
]()
{
...
...
@@ -649,6 +645,8 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
assert
(
arg
.
K
%
K1
==
0
);
if
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx1030"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx1100"
||
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp
View file @
c79316e2
...
...
@@ -269,25 +269,29 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
is_same_v
<
ELayout
,
Row
>
)
{
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances
(
op_ptrs
);
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances
(
op_ptrs
);
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances
(
op_ptrs
);
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances
(
op_ptrs
);
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances
(
op_ptrs
);
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances
(
op_ptrs
);
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances
(
op_ptrs
);
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances
(
op_ptrs
);
add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ADataType
,
int8_t
>
&&
is_same_v
<
BDataType
,
int8_t
>
&&
...
...
@@ -297,25 +301,29 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
is_same_v
<
ELayout
,
Row
>
)
{
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instances
(
op_ptrs
);
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instances
(
op_ptrs
);
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instances
(
op_ptrs
);
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instances
(
op_ptrs
);
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances
(
op_ptrs
);
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances
(
op_ptrs
);
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances
(
op_ptrs
);
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances
(
op_ptrs
);
add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instances
(
op_ptrs
);
}
}
...
...
profiler/include/profiler/profile_batched_gemm_impl.hpp
View file @
c79316e2
...
...
@@ -141,6 +141,7 @@ bool profile_batched_gemm_impl(int do_verification,
for
(
auto
&
op_ptr
:
op_ptrs
)
{
std
::
unique_ptr
<
tensor_operation
::
device
::
BaseArgument
>
argument_ptr
;
// true branch for multi d dl kernel
if
constexpr
(
std
::
is_same
<
DeviceOp
,
ck
::
tensor_operation
::
device
::
DeviceBatchedGemm
<
ALayout
,
...
...
test/batched_gemm_multi_d/CMakeLists.txt
View file @
c79316e2
add_test_executable
(
test_batched_gemm_multi_d_fp16 batched_gemm_multi_d_fp16.cpp
)
target_link_libraries
(
test_batched_gemm_multi_d_fp16 PRIVATE utility
)
target_link_libraries
(
test_batched_gemm_multi_d_fp16 PRIVATE device_batched_gemm_multi_d_instance
)
add_test_executable
(
test_batched_gemm_multi_d_int8 batched_gemm_multi_d_int8.cpp
)
target_link_libraries
(
test_batched_gemm_multi_d_int8 PRIVATE utility
)
target_link_libraries
(
test_batched_gemm_multi_d_int8 PRIVATE device_batched_gemm_multi_d_instance
)
add_gtest_executable
(
test_batched_gemm_multi_d test_batched_gemm_multi_d.cpp
)
target_link_libraries
(
test_batched_gemm_multi_d PRIVATE utility device_batched_gemm_multi_d_instance
)
test/batched_gemm_multi_d/batched_gemm_multi_d_fp16.cpp
deleted
100644 → 0
View file @
e6334634
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp"
namespace
{
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Empty_Tuple
=
ck
::
Tuple
<>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
}
// namespace
int
main
()
{
int
M
=
512
;
int
N
=
256
;
int
K
=
128
;
int
BatchCount
=
3
;
bool
pass
=
true
;
using
namespace
ck
::
tensor_operation
::
device
;
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemmMultiD
<
Row
,
Row
,
Empty_Tuple
,
Row
,
ADataType
,
BDataType
,
Empty_Tuple
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemmMultiD
<
Row
,
Col
,
Empty_Tuple
,
Row
,
ADataType
,
BDataType
,
Empty_Tuple
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemmMultiD
<
Col
,
Row
,
Empty_Tuple
,
Row
,
ADataType
,
BDataType
,
Empty_Tuple
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemmMultiD
<
Col
,
Col
,
Empty_Tuple
,
Row
,
ADataType
,
BDataType
,
Empty_Tuple
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
std
::
cout
<<
"test BatchedGEMMMultiD fp16: "
<<
(
pass
?
"Pass"
:
"Fail"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
test/batched_gemm_multi_d/batched_gemm_multi_d_int8.cpp
deleted
100644 → 0
View file @
e6334634
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp"
namespace
{
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Empty_Tuple
=
ck
::
Tuple
<>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
}
// namespace
int
main
()
{
int
M
=
256
;
int
N
=
256
;
int
K
=
128
;
int
BatchCount
=
3
;
bool
pass
=
true
;
using
namespace
ck
::
tensor_operation
::
device
;
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemmMultiD
<
Row
,
Row
,
Empty_Tuple
,
Row
,
ADataType
,
BDataType
,
Empty_Tuple
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemmMultiD
<
Row
,
Col
,
Empty_Tuple
,
Row
,
ADataType
,
BDataType
,
Empty_Tuple
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemmMultiD
<
Col
,
Row
,
Empty_Tuple
,
Row
,
ADataType
,
BDataType
,
Empty_Tuple
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
pass
=
pass
&&
ck
::
profiler
::
profile_batched_gemm_impl
<
ADataType
,
BDataType
,
CDataType
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemmMultiD
<
Col
,
Col
,
Empty_Tuple
,
Row
,
ADataType
,
BDataType
,
Empty_Tuple
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
M
,
K
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
std
::
cout
<<
"test BatchedGEMMMultiD int8: "
<<
(
pass
?
"Pass"
:
"Fail"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
test/batched_gemm_multi_d/test_batched_gemm_multi_d.cpp
0 → 100644
View file @
c79316e2
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <gtest/gtest.h>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp"
namespace
{
using
F16
=
ck
::
half_t
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Empty_Tuple
=
ck
::
Tuple
<>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
}
// namespace
template
<
typename
Tuple
>
class
TestBatchedGemmMultiD
:
public
::
testing
::
Test
{
protected:
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
DataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
static
constexpr
int
M
=
512
;
static
constexpr
int
N
=
256
;
static
constexpr
int
K
=
128
;
static
constexpr
int
BatchCount
=
3
;
void
Run
()
{
using
namespace
ck
::
tensor_operation
::
device
;
const
bool
pass
=
ck
::
profiler
::
profile_batched_gemm_impl
<
DataType
,
DataType
,
DataType
,
ALayout
,
BLayout
,
CLayout
,
PassThrough
,
PassThrough
,
PassThrough
,
DeviceBatchedGemmMultiD
<
ALayout
,
BLayout
,
Empty_Tuple
,
CLayout
,
DataType
,
DataType
,
Empty_Tuple
,
DataType
,
PassThrough
,
PassThrough
,
PassThrough
>>
(
true
,
1
,
false
,
1
,
M
,
N
,
K
,
K
,
N
,
N
,
M
*
K
,
K
*
N
,
M
*
N
,
BatchCount
);
EXPECT_TRUE
(
pass
);
}
};
template
<
typename
Tuple
>
class
TestBatchedGemmMultiDF16
:
public
TestBatchedGemmMultiD
<
Tuple
>
{
};
template
<
typename
Tuple
>
class
TestBatchedGemmMultiDI8
:
public
TestBatchedGemmMultiD
<
Tuple
>
{
};
using
F16KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
Row
,
Row
,
Row
,
F16
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
>>
;
using
I8KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
Row
,
Row
,
Row
,
int8_t
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
int8_t
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
int8_t
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
int8_t
>>
;
TYPED_TEST_SUITE
(
TestBatchedGemmMultiDF16
,
F16KernelTypes
);
TYPED_TEST_SUITE
(
TestBatchedGemmMultiDI8
,
I8KernelTypes
);
TYPED_TEST
(
TestBatchedGemmMultiDF16
,
bilinear
)
{
this
->
Run
();
}
TYPED_TEST
(
TestBatchedGemmMultiDI8
,
scale
)
{
this
->
Run
();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment