Commit 5029a5a4 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents 5ec6a912 95907384
...@@ -16,8 +16,15 @@ struct intrin_smfmac_f32_16x16x32f16<16, 16> ...@@ -16,8 +16,15 @@ struct intrin_smfmac_f32_16x16x32f16<16, 16>
__device__ static void __device__ static void
Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
{ {
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
ignore = reg_idx;
#endif
} }
}; };
...@@ -31,8 +38,15 @@ struct intrin_smfmac_f32_16x16x32bf16<16, 16> ...@@ -31,8 +38,15 @@ struct intrin_smfmac_f32_16x16x32bf16<16, 16>
__device__ static void __device__ static void
Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
{ {
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
ignore = reg_idx;
#endif
} }
}; };
...@@ -46,8 +60,15 @@ struct intrin_smfmac_f32_32x32x16f16<32, 32> ...@@ -46,8 +60,15 @@ struct intrin_smfmac_f32_32x32x16f16<32, 32>
__device__ static void __device__ static void
Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
{ {
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16( reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0); reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
ignore = reg_idx;
#endif
} }
}; };
...@@ -61,8 +82,15 @@ struct intrin_smfmac_f32_32x32x16bf16<32, 32> ...@@ -61,8 +82,15 @@ struct intrin_smfmac_f32_32x32x16bf16<32, 32>
__device__ static void __device__ static void
Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
{ {
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16( reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0); reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
ignore = reg_idx;
#endif
} }
}; };
......
...@@ -98,8 +98,8 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -98,8 +98,8 @@ int profile_grouped_gemm(int argc, char* argv[])
int n_iter = 10; int n_iter = 10;
if(argc == 17) if(argc == 17)
{ {
n_warmup = std::stoi(argv[16]); n_warmup = std::stoi(argv[15]);
n_iter = std::stoi(argv[17]); n_iter = std::stoi(argv[16]);
} }
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
......
...@@ -71,6 +71,8 @@ function(add_test_executable TEST_NAME) ...@@ -71,6 +71,8 @@ function(add_test_executable TEST_NAME)
list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
elseif(ARGN MATCHES "_wmma") elseif(ARGN MATCHES "_wmma")
list(REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) list(REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
elseif(ARGN MATCHES "_smfmac")
list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a)
endif() endif()
set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
add_executable(${TEST_NAME} ${ARGN}) add_executable(${TEST_NAME} ${ARGN})
...@@ -150,6 +152,8 @@ function(add_gtest_executable TEST_NAME) ...@@ -150,6 +152,8 @@ function(add_gtest_executable TEST_NAME)
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
elseif(ARGN MATCHES "_wmma") elseif(ARGN MATCHES "_wmma")
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
elseif(ARGN MATCHES "_smfmac")
list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a)
endif() endif()
set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
add_executable(${TEST_NAME} ${ARGN}) add_executable(${TEST_NAME} ${ARGN})
...@@ -209,7 +213,7 @@ add_subdirectory(wrapper) ...@@ -209,7 +213,7 @@ add_subdirectory(wrapper)
if(GPU_TARGETS MATCHES "gfx11") if(GPU_TARGETS MATCHES "gfx11")
add_subdirectory(wmma_op) add_subdirectory(wmma_op)
endif() endif()
if(GPU_TARGETS MATCHES "gfx942") if(GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2
add_subdirectory(smfmac_op) add_subdirectory(smfmac_op)
endif() endif()
add_subdirectory(position_embedding) add_subdirectory(position_embedding)
...@@ -2,11 +2,11 @@ add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data_x ...@@ -2,11 +2,11 @@ add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data_x
if(result EQUAL 0) if(result EQUAL 0)
target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
endif() endif()
add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_xdl.cpp) add_gtest_executable(test_grouped_convnd_bwd_data_interface_xdl test_grouped_convnd_bwd_data_interface_xdl.cpp)
if(result EQUAL 0) if(result EQUAL 0)
target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) target_link_libraries(test_grouped_convnd_bwd_data_interface_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance)
endif() endif()
add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_wmma.cpp) add_gtest_executable(test_grouped_convnd_bwd_data_interface_wmma test_grouped_convnd_bwd_data_interface_wmma.cpp)
if(result EQUAL 0) if(result EQUAL 0)
target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) target_link_libraries(test_grouped_convnd_bwd_data_interface_wmma PRIVATE utility device_grouped_conv2d_bwd_data_instance)
endif() endif()
...@@ -52,6 +52,14 @@ class TestGroupedConvndBwdData : public ::testing::Test ...@@ -52,6 +52,14 @@ class TestGroupedConvndBwdData : public ::testing::Test
ck::utils::conv::ConvParam conv_param; ck::utils::conv::ConvParam conv_param;
void SetUp() override
{
if(!ck::is_gfx11_supported())
{
GTEST_SKIP();
}
}
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
bool Run() bool Run()
{ {
......
...@@ -5,13 +5,13 @@ if(GPU_TARGETS MATCHES "gfx9" OR DL_KERNELS) ...@@ -5,13 +5,13 @@ if(GPU_TARGETS MATCHES "gfx9" OR DL_KERNELS)
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv3d_bwd_weight_instance) target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv3d_bwd_weight_instance)
endif() endif()
add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp) add_gtest_executable(test_grouped_convnd_bwd_weight_interface_xdl test_grouped_convnd_bwd_weight_interface_xdl.cpp)
if(result EQUAL 0) if(result EQUAL 0)
target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility) target_link_libraries(test_grouped_convnd_bwd_weight_interface_xdl PRIVATE utility)
endif() endif()
add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp) add_gtest_executable(test_grouped_convnd_bwd_weight_interface_wmma test_grouped_convnd_bwd_weight_interface_wmma.cpp)
if(result EQUAL 0) if(result EQUAL 0)
target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility) target_link_libraries(test_grouped_convnd_bwd_weight_interface_wmma PRIVATE utility)
endif() endif()
add_gtest_executable(test_grouped_conv_bwd_weight_xdl_bilinear test_grouped_conv_bwd_weight_xdl_bilinear.cpp) add_gtest_executable(test_grouped_conv_bwd_weight_xdl_bilinear test_grouped_conv_bwd_weight_xdl_bilinear.cpp)
if(result EQUAL 0) if(result EQUAL 0)
......
...@@ -52,6 +52,14 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ...@@ -52,6 +52,14 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
ck::utils::conv::ConvParam conv_param; ck::utils::conv::ConvParam conv_param;
void SetUp() override
{
if(!ck::is_gfx11_supported())
{
GTEST_SKIP();
}
}
template <ck::index_t SplitK> template <ck::index_t SplitK>
bool Run() bool Run()
{ {
......
if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11") if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11")
add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp) add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp)
if(GPU_TARGETS MATCHES "gfx11") if((GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9"))
target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance)
else() else()
target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance)
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/utility/amd_wmma.hpp" #include "ck/utility/amd_wmma.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck { namespace ck {
namespace wmma_op_util { namespace wmma_op_util {
...@@ -373,7 +374,8 @@ struct TestWmma ...@@ -373,7 +374,8 @@ struct TestWmma
a, b, c_host, a_element_op, b_element_op, c_element_op); a, b, c_host, a_element_op, b_element_op, c_element_op);
// Act // Act
bool is_supported = ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device); bool is_supported = ck::is_gfx11_supported() &&
ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device);
if(is_supported) if(is_supported)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment