Commit 75640f22 authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge remote-tracking branch 'origin/develop' into aosewski/test_ggemm_splitk

parents a627599c d821d1e5
...@@ -135,7 +135,8 @@ __global__ void ...@@ -135,7 +135,8 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__)) defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__))
// offset base pointer for each work-group // offset base pointer for each work-group
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...@@ -711,7 +712,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -711,7 +712,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// check device // check device
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx908" ||
ck::get_device_name() == "gfx940")) ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" ||
ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102"))
{ {
return false; return false;
} }
......
...@@ -106,7 +106,8 @@ __global__ void ...@@ -106,7 +106,8 @@ __global__ void
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \
defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__))
// offset base pointer for each work-group // offset base pointer for each work-group
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...@@ -600,7 +601,9 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS ...@@ -600,7 +601,9 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
namespace ctc = tensor_layout::convolution; namespace ctc = tensor_layout::convolution;
// check device // check device
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030")) if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102"))
{ {
return false; return false;
} }
......
...@@ -1393,7 +1393,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl ...@@ -1393,7 +1393,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
// check device // check device
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030")) if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102"))
{ {
return false; return false;
} }
......
...@@ -485,7 +485,9 @@ struct DeviceGemmDl : public DeviceGemm<ALayout, ...@@ -485,7 +485,9 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030") if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
{ {
return GridwiseGemm::CheckValidity( return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
......
...@@ -51,7 +51,8 @@ __global__ void ...@@ -51,7 +51,8 @@ __global__ void
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__)) defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__))
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
...@@ -553,7 +554,8 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout, ...@@ -553,7 +554,8 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
{ {
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx908" || if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx908" ||
ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx940") ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" ||
ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102")
{ {
return GridwiseGemm::CheckValidity( return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_); arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
......
...@@ -1027,7 +1027,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1027,7 +1027,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
// check device // check device
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030")) if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102"))
{ {
return false; return false;
} }
......
...@@ -40,7 +40,8 @@ __global__ void ...@@ -40,7 +40,8 @@ __global__ void
const CDEElementwiseOperation cde_element_op) const CDEElementwiseOperation cde_element_op)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx1030__)) defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__) || defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
......
add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp) if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
target_link_libraries(test_batched_gemm_fp16 PRIVATE utility) add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp)
target_link_libraries(test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance) target_link_libraries(test_batched_gemm_fp16 PRIVATE utility)
target_link_libraries(test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance)
add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp) add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp)
target_link_libraries(test_batched_gemm_fp32 PRIVATE utility) target_link_libraries(test_batched_gemm_fp32 PRIVATE utility)
target_link_libraries(test_batched_gemm_fp32 PRIVATE device_batched_gemm_instance) target_link_libraries(test_batched_gemm_fp32 PRIVATE device_batched_gemm_instance)
add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp) add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp)
target_link_libraries(test_batched_gemm_bf16 PRIVATE utility) target_link_libraries(test_batched_gemm_bf16 PRIVATE utility)
target_link_libraries(test_batched_gemm_bf16 PRIVATE device_batched_gemm_instance) target_link_libraries(test_batched_gemm_bf16 PRIVATE device_batched_gemm_instance)
add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp) add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp)
target_link_libraries(test_batched_gemm_int8 PRIVATE utility) target_link_libraries(test_batched_gemm_int8 PRIVATE utility)
target_link_libraries(test_batched_gemm_int8 PRIVATE device_batched_gemm_instance) target_link_libraries(test_batched_gemm_int8 PRIVATE device_batched_gemm_instance)
endif()
\ No newline at end of file
add_custom_target(test_batched_gemm_gemm) if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
add_custom_target(test_batched_gemm_gemm)
add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp) add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp)
target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance) target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance)
add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16) add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16)
\ No newline at end of file endif()
\ No newline at end of file
add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp) if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility) add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp)
target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance) target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility)
target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance)
endif()
add_custom_target(test_batched_gemm_softmax_gemm) if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
add_custom_target(test_batched_gemm_softmax_gemm)
add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp) add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp)
target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance) target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance)
add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16) add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16)
\ No newline at end of file endif()
\ No newline at end of file
add_custom_target(test_batched_gemm_softmax_gemm_permute) if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
add_custom_target(test_batched_gemm_softmax_gemm_permute)
add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp) add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp)
add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp) add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp)
target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16) add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16) add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16)
add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp) add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp)
add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp) add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp)
target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16) add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16) add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16)
\ No newline at end of file endif()
\ No newline at end of file
add_gtest_executable(test_contraction test_contraction.cpp) add_gtest_executable(test_contraction test_contraction.cpp)
add_gtest_executable(test_contraction_interface test_contraction_interface.cpp)
target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance)
target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
add_gtest_executable(test_contraction_interface test_contraction_interface.cpp)
target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance)
endif()
add_gtest_executable(test_convnd_bwd_data convnd_bwd_data.cpp) if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
target_link_libraries(test_convnd_bwd_data PRIVATE utility device_conv1d_bwd_data_instance device_conv2d_bwd_data_instance device_conv3d_bwd_data_instance) add_gtest_executable(test_convnd_bwd_data convnd_bwd_data.cpp)
target_link_libraries(test_convnd_bwd_data PRIVATE utility device_conv1d_bwd_data_instance device_conv2d_bwd_data_instance device_conv3d_bwd_data_instance)
endif()
\ No newline at end of file
add_gtest_executable(test_convnd_fwd convnd_fwd.cpp) if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
target_link_libraries(test_convnd_fwd PRIVATE utility device_conv2d_fwd_instance) add_gtest_executable(test_convnd_fwd convnd_fwd.cpp)
target_link_libraries(test_convnd_fwd PRIVATE utility device_conv2d_fwd_instance)
endif()
add_custom_target(test_gemm_layernorm) if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
add_custom_target(test_gemm_layernorm)
add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp) add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp)
target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance)
target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance) add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16)
endif()
add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16)
add_gtest_executable(test_gemm_splitk test_gemm_splitk.cpp) if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
target_link_libraries(test_gemm_splitk PRIVATE utility device_gemm_splitk_instance) add_gtest_executable(test_gemm_splitk test_gemm_splitk.cpp)
target_link_libraries(test_gemm_splitk PRIVATE utility device_gemm_splitk_instance)
endif()
add_gtest_executable(test_grouped_convnd_bwd_weight grouped_convnd_bwd_weight.cpp) if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) add_gtest_executable(test_grouped_convnd_bwd_weight grouped_convnd_bwd_weight.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
endif()
\ No newline at end of file
add_custom_target(test_grouped_gemm) if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
add_custom_target(test_grouped_gemm)
add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp)
add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface.cpp)
target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance)
target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance)
add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp) add_dependencies(test_grouped_gemm test_grouped_gemm_splitk test_grouped_gemm_interface)
add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface.cpp) endif()
target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance)
target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance)
add_dependencies(test_grouped_gemm test_grouped_gemm_splitk test_grouped_gemm_interface)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment