Project: gaoqiong/composable_kernel_ROCM

Commit 742dd3aa, authored Sep 26, 2023 by Jun Liu
Merge branch 'amd-develop' into amd-master
Parents: c9013009, 1f02eaef
Showing 5 changed files with 72 additions and 26 deletions (+72, -26):

  CMakeLists.txt                                                              +10  -5
  library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp  +11  -9
  test/batchnorm/batchnorm_bwd_rank_4.cpp                                     +17  -4
  test/batchnorm/batchnorm_fwd_rank_4.cpp                                     +17  -4
  test/batchnorm/batchnorm_infer_rank_4.cpp                                   +17  -4
CMakeLists.txt

 cmake_minimum_required(VERSION 3.14)
-cmake_policy(SET CMP0140 NEW)
+if(POLICY CMP0140)
+    # policies CMP0140 not known to CMake until 3.25
+    cmake_policy(SET CMP0140 NEW)
+endif()
 # This has to be initialized before the project() command appears
 # Set the default of CMAKE_BUILD_TYPE to be release, unless user specifies with -D. MSVC_IDE does not use CMAKE_BUILD_TYPE
...
@@ -108,16 +111,18 @@ else()
     add_definitions(-DPROFILER_ONLY)
     set(GPU_TARGETS "" CACHE STRING "" FORCE)
     if(GPU_TARGETS)
-        message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx9, gfx10, or gfx11")
+        message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, or gfx11")
     endif()
-    if(GPU_ARCH MATCHES "gfx9")
-        rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx900;gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942")
+    if(GPU_ARCH MATCHES "gfx90")
+        rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx900;gfx906;gfx908;gfx90a")
+    elseif(GPU_ARCH MATCHES "gfx94")
+        rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx940;gfx941;gfx942")
     elseif(GPU_ARCH MATCHES "gfx10")
         rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030")
     elseif(GPU_ARCH MATCHES "gfx11")
         rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102")
     else()
-        message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx9, gfx10, or gfx11")
+        message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, or gfx11")
     endif()
     set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE)
 endif()
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp

...
@@ -11,12 +11,12 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

-#ifdef CK_ENABLE_FP16
 namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace instance {
+#ifdef CK_ENABLE_FP16
 void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
     std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
                                                     Row,
...
@@ -68,7 +68,8 @@ void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance
                                                     PassThrough,
                                                     PassThrough,
                                                     Bilinear>>>& instances);
+#endif
+#ifdef CK_ENABLE_INT8
 void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances(
     std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
                                                     Row,
...
@@ -120,7 +121,7 @@ void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances(
                                                     PassThrough,
                                                     PassThrough,
                                                     Bilinear>>>& instances);
+#endif
 // GEMM + Bilinear
 template <typename ALayout,
           typename BLayout,
...
@@ -158,7 +159,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
     static auto GetInstances()
     {
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
+#ifdef CK_ENABLE_FP16
         if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                      is_same_v<DDataType, half_t> && is_same_v<EDataType, half_t>)
         {
...
@@ -187,8 +188,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
                     op_ptrs);
             }
         }
-        else if constexpr(is_same_v<ADataType, std::int8_t> && is_same_v<BDataType, std::int8_t> &&
-                          is_same_v<DDataType, std::int8_t> && is_same_v<EDataType, std::int8_t>)
+#endif
+#ifdef CK_ENABLE_INT8
+        if constexpr(is_same_v<ADataType, std::int8_t> && is_same_v<BDataType, std::int8_t> &&
+                     is_same_v<DDataType, std::int8_t> && is_same_v<EDataType, std::int8_t>)
         {
             if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                          is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
...
@@ -211,7 +214,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
                 add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances(
                     op_ptrs);
             }
         }
+#endif
         return op_ptrs;
     }
 };
...
@@ -220,4 +223,3 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
-#endif
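The reworked GetInstances above follows a pattern worth noting: each data-type branch is now a separate if constexpr inside its own #ifdef guard (rather than an else-if chain), so either block can be compiled out independently without leaving a dangling else. Below is a minimal, standalone sketch of that pattern; MyOp, MyFp16Op, MyInt8Op, get_instances and the MY_ENABLE_* macros are hypothetical stand-ins, not part of the composable_kernel API, and float stands in for half_t.

    // Standalone sketch of the guard-plus-if-constexpr dispatch pattern.
    #include <cstdint>
    #include <memory>
    #include <type_traits>
    #include <vector>

    #define MY_ENABLE_FP16 1
    #define MY_ENABLE_INT8 1

    struct MyOp { virtual ~MyOp() = default; };
    struct MyFp16Op : MyOp {};
    struct MyInt8Op : MyOp {};

    template <typename DataType>
    std::vector<std::unique_ptr<MyOp>> get_instances()
    {
        std::vector<std::unique_ptr<MyOp>> op_ptrs;
    #ifdef MY_ENABLE_FP16
        // Separate if constexpr (no else), so this block can be compiled out
        // independently of the int8 block below.
        if constexpr(std::is_same_v<DataType, float>)
        {
            op_ptrs.push_back(std::make_unique<MyFp16Op>());
        }
    #endif
    #ifdef MY_ENABLE_INT8
        if constexpr(std::is_same_v<DataType, std::int8_t>)
        {
            op_ptrs.push_back(std::make_unique<MyInt8Op>());
        }
    #endif
        return op_ptrs;
    }

    // Usage: get_instances<float>() returns the fp16-like instances,
    // get_instances<std::int8_t>() the int8 ones, and any other type an empty list.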
test/batchnorm/batchnorm_bwd_rank_4.cpp

...
@@ -70,10 +70,23 @@ class TestBatchNormBwdRank4 : public ::testing::Test
     }
 };

-using KernelTypes = ::testing::Types<std::tuple<F16, F32, F32, F32, F16, F32, F32>,
-                                     std::tuple<F32, F32, F32, F32, F32, F32, F32>,
-                                     std::tuple<BF16, F32, F32, F32, BF16, F32, F32>,
-                                     std::tuple<F64, F64, F64, F64, F64, F64, F64>>;
+using KernelTypes = ::testing::Types<
+#ifdef CK_ENABLE_FP16
+    std::tuple<F16, F32, F32, F32, F16, F32, F32>
+#endif
+#ifdef CK_ENABLE_FP32
+    ,
+    std::tuple<F32, F32, F32, F32, F32, F32, F32>
+#endif
+#ifdef CK_ENABLE_BF16
+    ,
+    std::tuple<BF16, F32, F32, F32, BF16, F32, F32>
+#endif
+#ifdef CK_ENABLE_FP64
+    ,
+    std::tuple<F64, F64, F64, F64, F64, F64, F64>
+#endif
+    >;

 TYPED_TEST_SUITE(TestBatchNormBwdRank4, KernelTypes);
...
test/batchnorm/batchnorm_fwd_rank_4.cpp

...
@@ -87,10 +87,23 @@ class TestBatchNormFwdRank4 : public ::testing::Test
     }
 };

-using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, F16, F32>,
-                                     std::tuple<F32, F32, F32, F32, F32, F32>,
-                                     std::tuple<BF16, BF16, F32, BF16, BF16, F32>,
-                                     std::tuple<F64, F64, F64, F64, F64, F64>>;
+using KernelTypes = ::testing::Types<
+#ifdef CK_ENABLE_FP16
+    std::tuple<F16, F16, F32, F16, F16, F32>
+#endif
+#ifdef CK_ENABLE_FP32
+    ,
+    std::tuple<F32, F32, F32, F32, F32, F32>
+#endif
+#ifdef CK_ENABLE_BF16
+    ,
+    std::tuple<BF16, BF16, F32, BF16, BF16, F32>
+#endif
+#ifdef CK_ENABLE_FP64
+    ,
+    std::tuple<F64, F64, F64, F64, F64, F64>
+#endif
+    >;

 TYPED_TEST_SUITE(TestBatchNormFwdRank4, KernelTypes);
...
test/batchnorm/batchnorm_infer_rank_4.cpp

...
@@ -67,10 +67,23 @@ class TestBatchNormInferRank4 : public ::testing::Test
     }
 };

-using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, F16, F32>,
-                                     std::tuple<F32, F32, F32, F32, F32, F32>,
-                                     std::tuple<BF16, BF16, F32, BF16, BF16, F32>,
-                                     std::tuple<F64, F64, F64, F64, F64, F64>>;
+using KernelTypes = ::testing::Types<
+#ifdef CK_ENABLE_FP16
+    std::tuple<F16, F16, F32, F16, F16, F32>
+#endif
+#ifdef CK_ENABLE_FP32
+    ,
+    std::tuple<F32, F32, F32, F32, F32, F32>
+#endif
+#ifdef CK_ENABLE_BF16
+    ,
+    std::tuple<BF16, BF16, F32, BF16, BF16, F32>
+#endif
+#ifdef CK_ENABLE_FP64
+    ,
+    std::tuple<F64, F64, F64, F64, F64, F64>
+#endif
+    >;

 TYPED_TEST_SUITE(TestBatchNormInferRank4, KernelTypes);
...
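All three batchnorm tests receive the same treatment: each std::tuple entry of the typed-test list is wrapped in its CK_ENABLE_* guard, with later entries carrying a leading comma, so the lists compile as long as at least the first guarded entry is enabled. Below is a minimal, self-contained sketch of that pattern; DummyTest and the EXAMPLE_ENABLE_* macros are hypothetical, while ::testing::Types, TYPED_TEST_SUITE and TYPED_TEST are real GoogleTest APIs.

    // Standalone sketch of a typed-test list with conditionally compiled entries.
    #include <gtest/gtest.h>
    #include <tuple>
    #include <type_traits>

    #define EXAMPLE_ENABLE_FP32 1
    #define EXAMPLE_ENABLE_FP64 1

    template <typename Tuple>
    class DummyTest : public ::testing::Test
    {
    };

    using DummyTypes = ::testing::Types<
    #ifdef EXAMPLE_ENABLE_FP32
        std::tuple<float, float>
    #endif
    #ifdef EXAMPLE_ENABLE_FP64
        ,
        std::tuple<double, double>
    #endif
        >;

    TYPED_TEST_SUITE(DummyTest, DummyTypes);

    TYPED_TEST(DummyTest, GuardedTypeListCompiles)
    {
        // TypeParam is the std::tuple selected for this instantiation.
        using First = std::tuple_element_t<0, TypeParam>;
        EXPECT_TRUE(std::is_floating_point<First>::value);
    }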