Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
5029a5a4
Commit
5029a5a4
authored
Jul 03, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
5ec6a912
95907384
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
63 additions
and
13 deletions
+63
-13
include/ck/utility/amd_smfmac.hpp
include/ck/utility/amd_smfmac.hpp
+28
-0
profiler/src/profile_grouped_gemm.cpp
profiler/src/profile_grouped_gemm.cpp
+2
-2
test/CMakeLists.txt
test/CMakeLists.txt
+5
-1
test/grouped_convnd_bwd_data/CMakeLists.txt
test/grouped_convnd_bwd_data/CMakeLists.txt
+4
-4
test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp
..._bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp
+8
-0
test/grouped_convnd_bwd_weight/CMakeLists.txt
test/grouped_convnd_bwd_weight/CMakeLists.txt
+4
-4
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp
..._weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp
+8
-0
test/grouped_convnd_fwd/CMakeLists.txt
test/grouped_convnd_fwd/CMakeLists.txt
+1
-1
test/wmma_op/wmma_op_util.hpp
test/wmma_op/wmma_op_util.hpp
+3
-1
No files found.
include/ck/utility/amd_smfmac.hpp
View file @
5029a5a4
...
@@ -16,8 +16,15 @@ struct intrin_smfmac_f32_16x16x32f16<16, 16>
...
@@ -16,8 +16,15 @@ struct intrin_smfmac_f32_16x16x32f16<16, 16>
__device__
static
void
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
int32_t
&
reg_idx
,
FloatC
&
reg_c
)
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
int32_t
&
reg_idx
,
FloatC
&
reg_c
)
{
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_16x16x32_f16
(
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_16x16x32_f16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
ignore
=
reg_c
;
ignore
=
reg_idx
;
#endif
}
}
};
};
...
@@ -31,8 +38,15 @@ struct intrin_smfmac_f32_16x16x32bf16<16, 16>
...
@@ -31,8 +38,15 @@ struct intrin_smfmac_f32_16x16x32bf16<16, 16>
__device__
static
void
__device__
static
void
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
int32_t
&
reg_idx
,
FloatC
&
reg_c
)
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
int32_t
&
reg_idx
,
FloatC
&
reg_c
)
{
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_16x16x32_bf16
(
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_16x16x32_bf16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
ignore
=
reg_c
;
ignore
=
reg_idx
;
#endif
}
}
};
};
...
@@ -46,8 +60,15 @@ struct intrin_smfmac_f32_32x32x16f16<32, 32>
...
@@ -46,8 +60,15 @@ struct intrin_smfmac_f32_32x32x16f16<32, 32>
__device__
static
void
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
int32_t
&
reg_idx
,
FloatC
&
reg_c
)
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
int32_t
&
reg_idx
,
FloatC
&
reg_c
)
{
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_32x32x16_f16
(
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_32x32x16_f16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
ignore
=
reg_c
;
ignore
=
reg_idx
;
#endif
}
}
};
};
...
@@ -61,8 +82,15 @@ struct intrin_smfmac_f32_32x32x16bf16<32, 32>
...
@@ -61,8 +82,15 @@ struct intrin_smfmac_f32_32x32x16bf16<32, 32>
__device__
static
void
__device__
static
void
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
int32_t
&
reg_idx
,
FloatC
&
reg_c
)
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
int32_t
&
reg_idx
,
FloatC
&
reg_c
)
{
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_32x32x16_bf16
(
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_32x32x16_bf16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
ignore
=
reg_c
;
ignore
=
reg_idx
;
#endif
}
}
};
};
...
...
profiler/src/profile_grouped_gemm.cpp
View file @
5029a5a4
...
@@ -98,8 +98,8 @@ int profile_grouped_gemm(int argc, char* argv[])
...
@@ -98,8 +98,8 @@ int profile_grouped_gemm(int argc, char* argv[])
int
n_iter
=
10
;
int
n_iter
=
10
;
if
(
argc
==
17
)
if
(
argc
==
17
)
{
{
n_warmup
=
std
::
stoi
(
argv
[
1
6
]);
n_warmup
=
std
::
stoi
(
argv
[
1
5
]);
n_iter
=
std
::
stoi
(
argv
[
1
7
]);
n_iter
=
std
::
stoi
(
argv
[
1
6
]);
}
}
#ifdef CK_ENABLE_FP16
#ifdef CK_ENABLE_FP16
...
...
test/CMakeLists.txt
View file @
5029a5a4
...
@@ -71,6 +71,8 @@ function(add_test_executable TEST_NAME)
...
@@ -71,6 +71,8 @@ function(add_test_executable TEST_NAME)
list
(
REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103
)
list
(
REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103
)
elseif
(
ARGN MATCHES
"_wmma"
)
elseif
(
ARGN MATCHES
"_wmma"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
list
(
REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
elseif
(
ARGN MATCHES
"_smfmac"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a
)
endif
()
endif
()
set_source_files_properties
(
${
ARGN
}
PROPERTIES LANGUAGE HIP
)
set_source_files_properties
(
${
ARGN
}
PROPERTIES LANGUAGE HIP
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
...
@@ -150,6 +152,8 @@ function(add_gtest_executable TEST_NAME)
...
@@ -150,6 +152,8 @@ function(add_gtest_executable TEST_NAME)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103
)
elseif
(
ARGN MATCHES
"_wmma"
)
elseif
(
ARGN MATCHES
"_wmma"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
elseif
(
ARGN MATCHES
"_smfmac"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a
)
endif
()
endif
()
set_source_files_properties
(
${
ARGN
}
PROPERTIES LANGUAGE HIP
)
set_source_files_properties
(
${
ARGN
}
PROPERTIES LANGUAGE HIP
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
...
@@ -209,7 +213,7 @@ add_subdirectory(wrapper)
...
@@ -209,7 +213,7 @@ add_subdirectory(wrapper)
if
(
GPU_TARGETS MATCHES
"gfx11"
)
if
(
GPU_TARGETS MATCHES
"gfx11"
)
add_subdirectory
(
wmma_op
)
add_subdirectory
(
wmma_op
)
endif
()
endif
()
if
(
GPU_TARGETS MATCHES
"gfx942"
)
if
(
GPU_TARGETS MATCHES
"gfx942"
AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2
)
# smfmac needs ROCm6.2
add_subdirectory
(
smfmac_op
)
add_subdirectory
(
smfmac_op
)
endif
()
endif
()
add_subdirectory
(
position_embedding
)
add_subdirectory
(
position_embedding
)
test/grouped_convnd_bwd_data/CMakeLists.txt
View file @
5029a5a4
...
@@ -2,11 +2,11 @@ add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data_x
...
@@ -2,11 +2,11 @@ add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data_x
if
(
result EQUAL 0
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance
)
target_link_libraries
(
test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance
)
endif
()
endif
()
add_gtest_executable
(
test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_xdl.cpp
)
add_gtest_executable
(
test_grouped_convnd_bwd_data_interface
_xdl
test_grouped_convnd_bwd_data_interface_xdl.cpp
)
if
(
result EQUAL 0
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance
)
target_link_libraries
(
test_grouped_convnd_bwd_data_interface
_xdl
PRIVATE utility device_grouped_conv2d_bwd_data_instance
)
endif
()
endif
()
add_gtest_executable
(
test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_wmma.cpp
)
add_gtest_executable
(
test_grouped_convnd_bwd_data_interface
_wmma
test_grouped_convnd_bwd_data_interface_wmma.cpp
)
if
(
result EQUAL 0
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance
)
target_link_libraries
(
test_grouped_convnd_bwd_data_interface
_wmma
PRIVATE utility device_grouped_conv2d_bwd_data_instance
)
endif
()
endif
()
test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp
View file @
5029a5a4
...
@@ -52,6 +52,14 @@ class TestGroupedConvndBwdData : public ::testing::Test
...
@@ -52,6 +52,14 @@ class TestGroupedConvndBwdData : public ::testing::Test
ck
::
utils
::
conv
::
ConvParam
conv_param
;
ck
::
utils
::
conv
::
ConvParam
conv_param
;
void
SetUp
()
override
{
if
(
!
ck
::
is_gfx11_supported
())
{
GTEST_SKIP
();
}
}
template
<
ck
::
index_t
NDimSpatial
>
template
<
ck
::
index_t
NDimSpatial
>
bool
Run
()
bool
Run
()
{
{
...
...
test/grouped_convnd_bwd_weight/CMakeLists.txt
View file @
5029a5a4
...
@@ -5,13 +5,13 @@ if(GPU_TARGETS MATCHES "gfx9" OR DL_KERNELS)
...
@@ -5,13 +5,13 @@ if(GPU_TARGETS MATCHES "gfx9" OR DL_KERNELS)
add_gtest_executable
(
test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp
)
add_gtest_executable
(
test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp
)
target_link_libraries
(
test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv3d_bwd_weight_instance
)
target_link_libraries
(
test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv3d_bwd_weight_instance
)
endif
()
endif
()
add_gtest_executable
(
test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp
)
add_gtest_executable
(
test_grouped_convnd_bwd_weight_interface
_xdl
test_grouped_convnd_bwd_weight_interface_xdl.cpp
)
if
(
result EQUAL 0
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_grouped_convnd_bwd_weight_interface PRIVATE utility
)
target_link_libraries
(
test_grouped_convnd_bwd_weight_interface
_xdl
PRIVATE utility
)
endif
()
endif
()
add_gtest_executable
(
test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp
)
add_gtest_executable
(
test_grouped_convnd_bwd_weight_interface
_wmma
test_grouped_convnd_bwd_weight_interface_wmma.cpp
)
if
(
result EQUAL 0
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_grouped_convnd_bwd_weight_interface PRIVATE utility
)
target_link_libraries
(
test_grouped_convnd_bwd_weight_interface
_wmma
PRIVATE utility
)
endif
()
endif
()
add_gtest_executable
(
test_grouped_conv_bwd_weight_xdl_bilinear test_grouped_conv_bwd_weight_xdl_bilinear.cpp
)
add_gtest_executable
(
test_grouped_conv_bwd_weight_xdl_bilinear test_grouped_conv_bwd_weight_xdl_bilinear.cpp
)
if
(
result EQUAL 0
)
if
(
result EQUAL 0
)
...
...
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp
View file @
5029a5a4
...
@@ -52,6 +52,14 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
...
@@ -52,6 +52,14 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
ck
::
utils
::
conv
::
ConvParam
conv_param
;
ck
::
utils
::
conv
::
ConvParam
conv_param
;
void
SetUp
()
override
{
if
(
!
ck
::
is_gfx11_supported
())
{
GTEST_SKIP
();
}
}
template
<
ck
::
index_t
SplitK
>
template
<
ck
::
index_t
SplitK
>
bool
Run
()
bool
Run
()
{
{
...
...
test/grouped_convnd_fwd/CMakeLists.txt
View file @
5029a5a4
if
(
GPU_TARGETS MATCHES
"gfx9"
OR GPU_TARGETS MATCHES
"gfx11"
)
if
(
GPU_TARGETS MATCHES
"gfx9"
OR GPU_TARGETS MATCHES
"gfx11"
)
add_gtest_executable
(
test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp
)
add_gtest_executable
(
test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp
)
if
(
GPU_TARGETS MATCHES
"gfx11"
)
if
(
(
GPU_TARGETS MATCHES
"gfx11"
)
AND
(
NOT GPU_TARGETS MATCHES
"gfx9"
))
target_link_libraries
(
test_grouped_convnd_fwd PRIVATE utility device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance
)
target_link_libraries
(
test_grouped_convnd_fwd PRIVATE utility device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance
)
else
()
else
()
target_link_libraries
(
test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance
)
target_link_libraries
(
test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance
)
...
...
test/wmma_op/wmma_op_util.hpp
View file @
5029a5a4
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/utility/amd_wmma.hpp"
#include "ck/utility/amd_wmma.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace
ck
{
namespace
ck
{
namespace
wmma_op_util
{
namespace
wmma_op_util
{
...
@@ -373,7 +374,8 @@ struct TestWmma
...
@@ -373,7 +374,8 @@ struct TestWmma
a
,
b
,
c_host
,
a_element_op
,
b_element_op
,
c_element_op
);
a
,
b
,
c_host
,
a_element_op
,
b_element_op
,
c_element_op
);
// Act
// Act
bool
is_supported
=
ck
::
wmma_op_util
::
RunDeviceGEMM
(
wmma_kernel
,
a
,
b
,
c_device
);
bool
is_supported
=
ck
::
is_gfx11_supported
()
&&
ck
::
wmma_op_util
::
RunDeviceGEMM
(
wmma_kernel
,
a
,
b
,
c_device
);
if
(
is_supported
)
if
(
is_supported
)
{
{
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment