"host/host_tensor/src/device.cpp" did not exist on "1566b31736d191fe3a43dd5efa59968e44191729"
Unverified Commit 3049b546 authored by Haocong WANG's avatar Haocong WANG Committed by GitHub
Browse files

[GEMM] gemm_universal related optimization (#1453)



* replace buffer_atomic with global_atomic

* fixed global_atomic_add

* added bf16 atomic_add

* format

* clang-format-12

* clean

* clean

* add guards

* Update gtest.cmake

* enabled splitk_gemm_multi_d

* format

* add ckProfiler

* format

* fixed naming

* format

* clean

* clean

* add guards

* fix clang format

* format

* add kbatch printout

* clean

* Add rocm6.2 related gemm optimization

* Limit bf16 atomic usage

* remove redundant RCR gemm_universal instance

* Add RRR fp8 gemm universal instance

* Bug fix

* Add GPU_TARGET guard to FP8/BF8 target

* bug fix

* update cmake

* remove all fp8/bf8 example if arch not support

* Enable fp8 RRR support in ckProfiler

* limit greedy-reverse flag to gemm_universal in ckProfiler

---------
Co-authored-by: default avatarJing Zhang <jizhan@fb.com>
Co-authored-by: default avatarJing Zhang <jizhan@meta.com>
Co-authored-by: default avatarzjing14 <zhangjing14@gmail.com>
Co-authored-by: default avatarIllia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: default avatarillsilin <Illia.Silin@amd.com>
parent 50c42348
...@@ -62,8 +62,17 @@ if (DTYPES) ...@@ -62,8 +62,17 @@ if (DTYPES)
endif() endif()
message("DTYPES macro set to ${DTYPES}") message("DTYPES macro set to ${DTYPES}")
else() else()
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
set(CK_ENABLE_ALL_DTYPES "ON") set(CK_ENABLE_INT8 "ON")
set(CK_ENABLE_FP16 "ON")
set(CK_ENABLE_FP32 "ON")
set(CK_ENABLE_FP64 "ON")
set(CK_ENABLE_BF16 "ON")
if (GPU_TARGETS MATCHES "gfx94")
add_definitions(-DCK_ENABLE_FP8 -DCK_ENABLE_BF8)
set(CK_ENABLE_FP8 "ON")
set(CK_ENABLE_BF8 "ON")
endif()
endif() endif()
#for f8/bf8_t type #for f8/bf8_t type
......
...@@ -5,17 +5,17 @@ if(GPU_TARGETS MATCHES "gfx9") ...@@ -5,17 +5,17 @@ if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp) add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp)
target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations) target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations)
if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) if((DTYPES MATCHES "fp8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
add_executable(client_grouped_conv3d_fwd_fp8 grouped_conv3d_fwd_fp8.cpp) add_executable(client_grouped_conv3d_fwd_fp8 grouped_conv3d_fwd_fp8.cpp)
target_link_libraries(client_grouped_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations) target_link_libraries(client_grouped_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations)
endif() endif()
if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) if((DTYPES MATCHES "bf8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
add_executable(client_grouped_conv3d_fwd_bf8 grouped_conv3d_fwd_bf8.cpp) add_executable(client_grouped_conv3d_fwd_bf8 grouped_conv3d_fwd_bf8.cpp)
target_link_libraries(client_grouped_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations) target_link_libraries(client_grouped_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations)
endif() endif()
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
add_executable(client_grouped_conv3d_fwd_fp8_bf8 grouped_conv3d_fwd_fp8_bf8.cpp) add_executable(client_grouped_conv3d_fwd_fp8_bf8 grouped_conv3d_fwd_fp8_bf8.cpp)
target_link_libraries(client_grouped_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations) target_link_libraries(client_grouped_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations)
......
...@@ -4,5 +4,7 @@ target_link_libraries(client_grouped_conv2d_bwd_data PRIVATE composable_kernel:: ...@@ -4,5 +4,7 @@ target_link_libraries(client_grouped_conv2d_bwd_data PRIVATE composable_kernel::
add_executable(client_grouped_conv3d_bwd_data grouped_conv3d_bwd_data.cpp) add_executable(client_grouped_conv3d_bwd_data grouped_conv3d_bwd_data.cpp)
target_link_libraries(client_grouped_conv3d_bwd_data PRIVATE composable_kernel::device_conv_operations) target_link_libraries(client_grouped_conv3d_bwd_data PRIVATE composable_kernel::device_conv_operations)
add_executable(client_grouped_conv3d_bwd_data_input_fp16_comp_bf8f8 grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
target_link_libraries(client_grouped_conv3d_bwd_data_input_fp16_comp_bf8f8 PRIVATE composable_kernel::device_conv_operations) add_executable(client_grouped_conv3d_bwd_data_input_fp16_comp_bf8f8 grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp)
target_link_libraries(client_grouped_conv3d_bwd_data_input_fp16_comp_bf8f8 PRIVATE composable_kernel::device_conv_operations)
endif()
\ No newline at end of file
...@@ -2,10 +2,13 @@ add_executable(client_grouped_conv1d_bwd_weight_fp16 grouped_conv1d_bwd_weight_f ...@@ -2,10 +2,13 @@ add_executable(client_grouped_conv1d_bwd_weight_fp16 grouped_conv1d_bwd_weight_f
add_executable(client_grouped_conv2d_bwd_weight_fp16 grouped_conv2d_bwd_weight_fp16.cpp) add_executable(client_grouped_conv2d_bwd_weight_fp16 grouped_conv2d_bwd_weight_fp16.cpp)
add_executable(client_grouped_conv3d_bwd_weight_fp16 grouped_conv3d_bwd_weight_fp16.cpp) add_executable(client_grouped_conv3d_bwd_weight_fp16 grouped_conv3d_bwd_weight_fp16.cpp)
add_executable(client_grouped_conv3d_bwd_weight_fp32 grouped_conv3d_bwd_weight_fp32.cpp) add_executable(client_grouped_conv3d_bwd_weight_fp32 grouped_conv3d_bwd_weight_fp32.cpp)
add_executable(client_grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8 grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8.cpp)
target_link_libraries(client_grouped_conv1d_bwd_weight_fp16 PRIVATE composable_kernel::device_conv_operations) target_link_libraries(client_grouped_conv1d_bwd_weight_fp16 PRIVATE composable_kernel::device_conv_operations)
target_link_libraries(client_grouped_conv2d_bwd_weight_fp16 PRIVATE composable_kernel::device_conv_operations) target_link_libraries(client_grouped_conv2d_bwd_weight_fp16 PRIVATE composable_kernel::device_conv_operations)
target_link_libraries(client_grouped_conv3d_bwd_weight_fp16 PRIVATE composable_kernel::device_conv_operations) target_link_libraries(client_grouped_conv3d_bwd_weight_fp16 PRIVATE composable_kernel::device_conv_operations)
target_link_libraries(client_grouped_conv3d_bwd_weight_fp32 PRIVATE composable_kernel::device_conv_operations) target_link_libraries(client_grouped_conv3d_bwd_weight_fp32 PRIVATE composable_kernel::device_conv_operations)
target_link_libraries(client_grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8 PRIVATE composable_kernel::device_conv_operations)
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
add_executable(client_grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8 grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8.cpp)
target_link_libraries(client_grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8 PRIVATE composable_kernel::device_conv_operations)
endif()
\ No newline at end of file
...@@ -4,7 +4,7 @@ if((DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) ...@@ -4,7 +4,7 @@ if((DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
endif() endif()
if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) if((DTYPES MATCHES "fp8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp) add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp)
target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_conv_operations) target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_conv_operations)
endif() endif()
......
if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp) add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp)
target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_gemm_operations) target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_gemm_operations)
endif() endif()
...@@ -34,8 +34,17 @@ if (DTYPES) ...@@ -34,8 +34,17 @@ if (DTYPES)
endif() endif()
message("DTYPES macro set to ${DTYPES}") message("DTYPES macro set to ${DTYPES}")
else() else()
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
set(CK_ENABLE_ALL_DTYPES "ON") set(CK_ENABLE_INT8 "ON")
set(CK_ENABLE_FP16 "ON")
set(CK_ENABLE_FP32 "ON")
set(CK_ENABLE_FP64 "ON")
set(CK_ENABLE_BF16 "ON")
if (GPU_TARGETS MATCHES "gfx94")
add_definitions(-DCK_ENABLE_FP8 -DCK_ENABLE_BF8)
set(CK_ENABLE_FP8 "ON")
set(CK_ENABLE_BF8 "ON")
endif()
endif() endif()
if (GPU_TARGETS) if (GPU_TARGETS)
......
...@@ -23,12 +23,8 @@ ...@@ -23,12 +23,8 @@
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
#ifdef CK_ENABLE_FP8 using F8 = ck::f8_t;
using F8 = ck::f8_t; using BF8 = ck::bf8_t;
#endif
#ifdef CK_ENABLE_BF8
using BF8 = ck::bf8_t;
#endif
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
......
...@@ -208,6 +208,7 @@ int main(int argc, char* argv[]) ...@@ -208,6 +208,7 @@ int main(int argc, char* argv[])
StrideB, StrideB,
std::array<ck::index_t, NumDTensor>{StrideD, StrideD}, std::array<ck::index_t, NumDTensor>{StrideD, StrideD},
StrideE, StrideE,
1,
a_element_op, a_element_op,
b_element_op, b_element_op,
cde_element_op); cde_element_op);
......
...@@ -69,7 +69,7 @@ using AElementOp = PassThrough; ...@@ -69,7 +69,7 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using CDEElementOp = MultiplyMultiply; using CDEElementOp = MultiplyMultiply;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
// clang-format off // clang-format off
...@@ -99,6 +99,8 @@ int main(int argc, char* argv[]) ...@@ -99,6 +99,8 @@ int main(int argc, char* argv[])
ck::index_t StrideD = 0; ck::index_t StrideD = 0;
ck::index_t StrideE = N; ck::index_t StrideE = N;
ck::index_t KBatch = 1;
if(argc == 1) if(argc == 1)
{ {
// use default case // use default case
...@@ -109,7 +111,7 @@ int main(int argc, char* argv[]) ...@@ -109,7 +111,7 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 11) else if(argc == 12)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
...@@ -123,13 +125,16 @@ int main(int argc, char* argv[]) ...@@ -123,13 +125,16 @@ int main(int argc, char* argv[])
StrideB = std::stoi(argv[8]); StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]); StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]); StrideE = std::stoi(argv[10]);
KBatch = std::stoi(argv[11]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); printf(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n");
exit(0); exit(0);
} }
...@@ -212,6 +217,7 @@ int main(int argc, char* argv[]) ...@@ -212,6 +217,7 @@ int main(int argc, char* argv[])
StrideB, StrideB,
std::array<ck::index_t, NumDTensor>{I0, I0}, std::array<ck::index_t, NumDTensor>{I0, I0},
StrideE, StrideE,
KBatch,
a_element_op, a_element_op,
b_element_op, b_element_op,
cde_element_op); cde_element_op);
...@@ -236,10 +242,12 @@ int main(int argc, char* argv[]) ...@@ -236,10 +242,12 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl; << std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification) if(do_verification)
{ {
invoker.Run(argument, StreamConfig{nullptr, false});
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
Tensor<CShuffleDataType> c_m_n({M, N}); Tensor<CShuffleDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
......
...@@ -72,6 +72,20 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) ...@@ -72,6 +72,20 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
list(REMOVE_ITEM FILE_NAME "${source}") list(REMOVE_ITEM FILE_NAME "${source}")
endif() endif()
endforeach() endforeach()
#Do not build any FP8 examples if CK_ENABLE_FP8 not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED CK_ENABLE_FP8 AND source MATCHES "_fp8")
message("removing fp8 example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any BF8 examples if CK_ENABLE_BF8 not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED CK_ENABLE_BF8 AND source MATCHES "_bf8")
message("removing bf8 example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list #only continue if there are some source files left on the list
if(FILE_NAME) if(FILE_NAME)
if(FILE_NAME MATCHES "_xdl") if(FILE_NAME MATCHES "_xdl")
......
...@@ -65,6 +65,12 @@ inline bool is_lds_direct_load_supported() ...@@ -65,6 +65,12 @@ inline bool is_lds_direct_load_supported()
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942"; ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
} }
inline bool is_bf16_atomic_supported()
{
return ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942";
}
inline bool is_gfx101_supported() inline bool is_gfx101_supported()
{ {
return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" || return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" ||
......
...@@ -53,6 +53,49 @@ struct DeviceGemmMultipleD : public BaseOperator ...@@ -53,6 +53,49 @@ struct DeviceGemmMultipleD : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
// GEMM:
// input : A[M, K], B[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleDSplitK : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideE,
ck::index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -69,17 +69,17 @@ template <typename ALayout, ...@@ -69,17 +69,17 @@ template <typename ALayout,
typename ComputeTypeB = ComputeTypeA, typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ComputeTypeA, typename LDSTypeA = ComputeTypeA,
typename LDSTypeB = ComputeTypeB> typename LDSTypeB = ComputeTypeB>
struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayout,
BLayout, BLayout,
DsLayout, DsLayout,
CLayout, CLayout,
ADataType, ADataType,
BDataType, BDataType,
DsDataType, DsDataType,
CDataType, CDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation> CElementwiseOperation>
{ {
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
...@@ -192,15 +192,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -192,15 +192,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// rotating mem // rotating mem
rotating_mem.Next(); rotating_mem.Next();
// clear c mem // clear c mem
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value) if(arg_.KBatch > 1)
{ hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
if(arg_.KBatch > 1) 0,
hipGetErrorString( arg_.M * arg_.N * sizeof(CDataType),
hipMemsetAsync(arg_.p_c_grid, stream_config.stream_id_));
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
}
}; };
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>( ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
...@@ -234,38 +230,161 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -234,38 +230,161 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{ {
if(arg.KBatch > 1)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy>; minimum_occupancy>;
Run(kernel); Run(kernel);
} }
} }
// Tail number could be One to Seven // Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{ {
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
else
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::One>; TailNumber::One>;
Run(kernel); Run(kernel);
} }
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full) TailNumber::Full)
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Full>; TailNumber::Full>;
Run(kernel); Run(kernel);
} }
...@@ -273,12 +392,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -273,12 +392,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{ {
const auto kernel = const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Two>; TailNumber::Two>;
Run(kernel); Run(kernel);
} }
} }
...@@ -288,12 +407,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -288,12 +407,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three) TailNumber::Three)
{ {
const auto kernel = const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Three>; TailNumber::Three>;
Run(kernel); Run(kernel);
} }
} }
...@@ -303,12 +422,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -303,12 +422,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four) TailNumber::Four)
{ {
const auto kernel = const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Four>; TailNumber::Four>;
Run(kernel); Run(kernel);
} }
} }
...@@ -318,12 +437,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -318,12 +437,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five) TailNumber::Five)
{ {
const auto kernel = const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Five>; TailNumber::Five>;
Run(kernel); Run(kernel);
} }
} }
...@@ -332,12 +451,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -332,12 +451,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{ {
const auto kernel = const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Six>; TailNumber::Six>;
Run(kernel); Run(kernel);
} }
} }
...@@ -347,12 +466,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -347,12 +466,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven) TailNumber::Seven)
{ {
const auto kernel = const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Seven>; TailNumber::Seven>;
Run(kernel); Run(kernel);
} }
} }
...@@ -361,51 +480,98 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -361,51 +480,98 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number could be Odd or Even // Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{ {
if(arg.KBatch > 1)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{ {
const auto kernel = const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm, GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, minimum_occupancy,
TailNumber::Odd>; TailNumber::Odd>;
Run(kernel); Run(kernel);
} }
else else
{ {
const auto kernel = const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm, GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, minimum_occupancy,
TailNumber::Even>; TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel); Run(kernel);
} }
} }
} }
else else
{ {
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Odd>; TailNumber::Odd>;
Run(kernel); Run(kernel);
} }
else else
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy, minimum_occupancy,
TailNumber::Even>; TailNumber::Even>;
Run(kernel); Run(kernel);
} }
} }
...@@ -416,12 +582,22 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -416,12 +582,22 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number always 1 // Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{ {
if(arg.KBatch > 1)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
false, false,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy>; minimum_occupancy>;
Run(kernel); Run(kernel);
} }
} }
...@@ -451,6 +627,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -451,6 +627,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
return false; return false;
} }
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding || GemmSpec == GemmSpecialization::MNKPadding ||
...@@ -479,6 +660,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -479,6 +660,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
index_t StrideB, index_t StrideB,
std::array<index_t, NumDTensor> StrideDs, std::array<index_t, NumDTensor> StrideDs,
index_t StrideC, index_t StrideC,
index_t KBatch,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
...@@ -494,7 +676,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -494,7 +676,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
StrideB, StrideB,
StrideDs, StrideDs,
StrideC, StrideC,
1, KBatch,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op}; c_element_op};
...@@ -514,6 +696,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -514,6 +696,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
index_t StrideB, index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs, std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideC, index_t StrideC,
index_t KBatch,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override CElementwiseOperation c_element_op) override
...@@ -529,7 +712,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -529,7 +712,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
StrideB, StrideB,
StrideDs, StrideDs,
StrideC, StrideC,
1, KBatch,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); c_element_op);
......
...@@ -168,15 +168,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -168,15 +168,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// rotating mem // rotating mem
rotating_mem.Next(); rotating_mem.Next();
// clear c mem // clear c mem
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value) if(arg_.KBatch > 1)
{ hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
if(arg_.KBatch > 1) 0,
hipGetErrorString( arg_.M * arg_.N * sizeof(CDataType),
hipMemsetAsync(arg_.p_c_grid, stream_config.stream_id_));
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
}
}; };
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>( ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
...@@ -190,14 +186,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -190,14 +186,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
} }
else else
{ {
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value) if(arg.KBatch > 1)
{ hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
if(arg.KBatch > 1) 0,
hipGetErrorString(hipMemsetAsync(arg.p_c_grid, arg.M * arg.N * sizeof(CDataType),
0, stream_config.stream_id_));
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
}
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
...@@ -215,15 +208,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -215,15 +208,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{ {
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value) const auto kernel =
{ kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
const auto kernel = true,
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, InMemoryDataOperationEnum::AtomicAdd,
true, minimum_occupancy>;
InMemoryDataOperationEnum::AtomicAdd, Run(kernel);
minimum_occupancy>;
Run(kernel);
}
} }
else else
{ {
...@@ -240,118 +230,113 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -240,118 +230,113 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{ {
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v3< const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm, GridwiseGemm,
true, true,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, minimum_occupancy,
TailNumber::One>; TailNumber::Two>;
Run(kernel); Run(kernel);
} }
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == }
TailNumber::Full)
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v3< const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm, GridwiseGemm,
true, true,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, minimum_occupancy,
TailNumber::Full>; TailNumber::Three>;
Run(kernel); Run(kernel);
} }
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two) TailNumber::Four)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == const auto kernel = kernel_gemm_xdl_cshuffle_v3<
TailNumber::Four) GridwiseGemm,
{ true,
const auto kernel = kernel_gemm_xdl_cshuffle_v3< InMemoryDataOperationEnum::AtomicAdd,
GridwiseGemm, minimum_occupancy,
true, TailNumber::Four>;
InMemoryDataOperationEnum::AtomicAdd, Run(kernel);
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
} }
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == const auto kernel = kernel_gemm_xdl_cshuffle_v3<
TailNumber::Five) GridwiseGemm,
{ true,
const auto kernel = kernel_gemm_xdl_cshuffle_v3< InMemoryDataOperationEnum::AtomicAdd,
GridwiseGemm, minimum_occupancy,
true, TailNumber::Five>;
InMemoryDataOperationEnum::AtomicAdd, Run(kernel);
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
} }
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == const auto kernel = kernel_gemm_xdl_cshuffle_v3<
TailNumber::Six) GridwiseGemm,
{ true,
const auto kernel = kernel_gemm_xdl_cshuffle_v3< InMemoryDataOperationEnum::AtomicAdd,
GridwiseGemm, minimum_occupancy,
true, TailNumber::Six>;
InMemoryDataOperationEnum::AtomicAdd, Run(kernel);
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
} }
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == const auto kernel = kernel_gemm_xdl_cshuffle_v3<
TailNumber::Seven) GridwiseGemm,
{ true,
const auto kernel = kernel_gemm_xdl_cshuffle_v3< InMemoryDataOperationEnum::AtomicAdd,
GridwiseGemm, minimum_occupancy,
true, TailNumber::Seven>;
InMemoryDataOperationEnum::AtomicAdd, Run(kernel);
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
} }
} }
} }
...@@ -473,28 +458,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -473,28 +458,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{ {
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
{ GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< true,
GridwiseGemm, InMemoryDataOperationEnum::AtomicAdd,
true, minimum_occupancy,
InMemoryDataOperationEnum::AtomicAdd, TailNumber::Odd>;
minimum_occupancy, Run(kernel);
TailNumber::Odd>; }
Run(kernel); else
} {
else const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
{ GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< true,
GridwiseGemm, InMemoryDataOperationEnum::AtomicAdd,
true, minimum_occupancy,
InMemoryDataOperationEnum::AtomicAdd, TailNumber::Even>;
minimum_occupancy, Run(kernel);
TailNumber::Even>;
Run(kernel);
}
} }
} }
else else
...@@ -525,28 +507,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -525,28 +507,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{ {
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) const auto kernel =
{ kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3< true,
GridwiseGemm, InMemoryDataOperationEnum::AtomicAdd,
true, minimum_occupancy,
InMemoryDataOperationEnum::AtomicAdd, TailNumber::Odd>;
minimum_occupancy, Run(kernel);
TailNumber::Odd>; }
Run(kernel); else
} {
else const auto kernel =
{ kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3< true,
GridwiseGemm, InMemoryDataOperationEnum::AtomicAdd,
true, minimum_occupancy,
InMemoryDataOperationEnum::AtomicAdd, TailNumber::Even>;
minimum_occupancy, Run(kernel);
TailNumber::Even>;
Run(kernel);
}
} }
} }
else else
...@@ -579,18 +558,14 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -579,18 +558,14 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// Tail number always 1 // Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{ {
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value) const auto kernel =
{ kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
const auto kernel = false,
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, InMemoryDataOperationEnum::AtomicAdd,
false, minimum_occupancy>;
InMemoryDataOperationEnum::AtomicAdd, Run(kernel);
minimum_occupancy>;
Run(kernel);
}
} }
else else
{ {
...@@ -628,6 +603,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -628,6 +603,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return false; return false;
} }
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding || GemmSpec == GemmSpecialization::MNKPadding ||
......
...@@ -417,6 +417,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -417,6 +417,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
} }
}(); }();
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization; using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding || if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
...@@ -454,6 +461,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -454,6 +461,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
// not pad M or N // not pad M or N
return c_grid_desc_mraw_nraw; return c_grid_desc_mraw_nraw;
} }
#endif
} }
struct Problem struct Problem
...@@ -953,7 +961,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -953,7 +961,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
{ {
if(!(karg.M % MPerBlock == 0)) if(!(karg.M % MPerBlock == 0))
{ {
...@@ -970,7 +979,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -970,7 +979,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
{ {
if(!(karg.N % NPerBlock == 0)) if(!(karg.N % NPerBlock == 0))
{ {
...@@ -1105,7 +1115,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1105,7 +1115,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
} }
if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value || if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
is_same<remove_cvref_t<CDataType>, float>::value)) is_same<remove_cvref_t<CDataType>, float>::value ||
is_same<remove_cvref_t<CDataType>, bhalf_t>::value ||
is_same<remove_cvref_t<CDataType>, int32_t>::value))
{ {
if(!karg.IsReduceAdd()) if(!karg.IsReduceAdd())
{ {
......
...@@ -36,10 +36,9 @@ __global__ void ...@@ -36,10 +36,9 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif #endif
// __attribute__((amdgpu_waves_per_eu(1, 1))) // __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) kernel_gemm_xdl_cshuffle_v3_multi_d(typename GridwiseGemm::Argument karg)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
...@@ -56,7 +55,7 @@ __global__ void ...@@ -56,7 +55,7 @@ __global__ void
karg.c_element_op); karg.c_element_op);
#else #else
ignore = karg; ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx9__))
} }
template <typename GridwiseGemm, template <typename GridwiseGemm,
...@@ -69,10 +68,9 @@ __global__ void ...@@ -69,10 +68,9 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif #endif
// __attribute__((amdgpu_waves_per_eu(1, 1))) // __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
// Pass two lds pointer is the key to tell compiler that ds_read/write // Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy // operate on different lds chunk at same time without order dependecy
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -93,7 +91,7 @@ __global__ void ...@@ -93,7 +91,7 @@ __global__ void
karg.c_element_op); karg.c_element_op);
#else #else
ignore = karg; ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx9__))
} }
template <typename ALayout, template <typename ALayout,
...@@ -454,6 +452,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 ...@@ -454,6 +452,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
} }
}(); }();
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization; using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding || if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
...@@ -491,6 +496,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 ...@@ -491,6 +496,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
// not pad M or N // not pad M or N
return c_grid_desc_mraw_nraw; return c_grid_desc_mraw_nraw;
} }
#endif
} }
__host__ __device__ static auto MakeDsGridDescriptor_M_N( __host__ __device__ static auto MakeDsGridDescriptor_M_N(
...@@ -1016,7 +1022,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 ...@@ -1016,7 +1022,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
{ {
if(!(karg.M % MPerBlock == 0)) if(!(karg.M % MPerBlock == 0))
{ {
...@@ -1033,7 +1040,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 ...@@ -1033,7 +1040,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
{ {
if(!(karg.N % NPerBlock == 0)) if(!(karg.N % NPerBlock == 0))
{ {
......
...@@ -562,6 +562,34 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src ...@@ -562,6 +562,34 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_addr_offset); dst_wave_addr_offset);
} }
template <typename T, index_t N>
__device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
T* addr)
{
static_assert((is_same<T, bhalf_t>::value && (N == 2 || N == 4 || N == 8)) ||
(is_same<T, half_t>::value && (N == 2 || N == 4 || N == 8)),
"wrong! not implemented");
if constexpr(is_same<T, half_t>::value)
{
vector_type<half_t, N> tmp{src_thread_data};
static_for<0, N / 2, 1>{}([&](auto i) {
__builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast<half2_t*>(addr) + i,
tmp.template AsType<half2_t>()[i]);
});
}
#if defined(__gfx942__)
else if constexpr(is_same<T, bhalf_t>::value)
{
vector_type<bhalf_t, N> tmp{src_thread_data};
static_for<0, N / 2, 1>{}([&](auto i) {
__builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast<bhalf2_t*>(addr) + i,
tmp.template AsType<bhalf2_t>()[i]);
});
}
#endif
}
template <typename T, index_t N> template <typename T, index_t N>
__device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data, __device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
int32x4_t dst_wave_buffer_resource, int32x4_t dst_wave_buffer_resource,
...@@ -907,18 +935,29 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr ...@@ -907,18 +935,29 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
using scalar_t = typename scalar_type<vector_t>::type; using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size; constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
if constexpr(is_same<T, bhalf_t>::value)
{
if(dst_thread_element_valid)
{
amd_global_atomic_add_impl<scalar_t, vector_size>(
src_thread_data, p_dst_wave + dst_thread_element_offset);
}
}
else
{
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK #if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
amd_buffer_atomic_add_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
if(dst_thread_element_valid)
{
amd_buffer_atomic_add_impl<scalar_t, vector_size>( amd_buffer_atomic_add_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
} #else
if(dst_thread_element_valid)
{
amd_buffer_atomic_add_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
#endif #endif
}
} }
// buffer_atomic_max requires: // buffer_atomic_max requires:
......
...@@ -358,13 +358,15 @@ struct DynamicBuffer ...@@ -358,13 +358,15 @@ struct DynamicBuffer
bool constexpr use_amd_buffer_addressing = bool constexpr use_amd_buffer_addressing =
is_same_v<remove_cvref_t<scalar_t>, int32_t> || is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
is_same_v<remove_cvref_t<scalar_t>, float> || is_same_v<remove_cvref_t<scalar_t>, float> ||
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0); (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
(is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT) #elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, int32_t>; bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT #elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing = bool constexpr use_amd_buffer_addressing =
is_same_v<remove_cvref_t<scalar_t>, float> || is_same_v<remove_cvref_t<scalar_t>, float> ||
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0); (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
(is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
#else #else
bool constexpr use_amd_buffer_addressing = false; bool constexpr use_amd_buffer_addressing = false;
#endif #endif
......
...@@ -18,134 +18,82 @@ namespace device { ...@@ -18,134 +18,82 @@ namespace device {
namespace instance { namespace instance {
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances( void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances( void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances( void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
#endif #endif
template <typename ADataType, template <typename ADataType,
...@@ -154,7 +102,7 @@ template <typename ADataType, ...@@ -154,7 +102,7 @@ template <typename ADataType,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
typename CLayout> typename CLayout>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD< struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleDSplitK<
ALayout, ALayout,
BLayout, BLayout,
Tuple<Row, Col>, Tuple<Row, Col>,
...@@ -167,17 +115,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu ...@@ -167,17 +115,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyMultiply>> ck::tensor_operation::element_wise::MultiplyMultiply>>
{ {
using DeviceOp = DeviceGemmMultipleD<ALayout, using DeviceOp =
BLayout, DeviceGemmMultipleDSplitK<ALayout,
Tuple<Row, Col>, BLayout,
CLayout, Tuple<Row, Col>,
ADataType, CLayout,
BDataType, ADataType,
Tuple<F32, F32>, BDataType,
CDataType, Tuple<F32, F32>,
ck::tensor_operation::element_wise::PassThrough, CDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyMultiply>; ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyMultiply>;
static auto GetInstances() static auto GetInstances()
{ {
...@@ -194,24 +143,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu ...@@ -194,24 +143,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
op_ptrs); op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
op_ptrs); op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
op_ptrs); op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances( add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
op_ptrs); op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
op_ptrs); op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances( add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
op_ptrs); op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
op_ptrs);
} }
} }
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment