Commit 930b2872 authored by Harisankar Sadasivan's avatar Harisankar Sadasivan
Browse files

best performing kernel for GEMV codex problem with M=1 with inverted B matrix

parents a1e17d18 a4f72a31
...@@ -10,7 +10,7 @@ using CDataType = ck::half_t; ...@@ -10,7 +10,7 @@ using CDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
using ALayout = Row; using ALayout = Row;
using BLayout = Col; using BLayout = Row; // Col;
using CLayout = Row; using CLayout = Row;
using AElementOp = PassThrough; using AElementOp = PassThrough;
...@@ -19,9 +19,9 @@ using CElementOp = PassThrough; ...@@ -19,9 +19,9 @@ using CElementOp = PassThrough;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
#define K1 8 // K1PerThread:2,4,8 #define K1 4
#define K0 4 // K0PerBlock:1,2,3,4...32 #define K0 3
#define N1 2 // Nperthread:2,4,8 #define N1 2
#define B 64 // block-size:64 #define B 64 // block-size:64
// clang-format off // clang-format off
...@@ -31,7 +31,7 @@ using DeviceGemvInstance = ck::tensor_operation::device::deviceGemvDl/* ...@@ -31,7 +31,7 @@ using DeviceGemvInstance = ck::tensor_operation::device::deviceGemvDl/*
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | | // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 64, 32, 2, 1, 1, 1, S<1, 1, 1, 2>, S<32, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 1>;*/ //< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 64, 32, 2, 1, 1, 1, S<1, 1, 1, 2>, S<32, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 1>;*/
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, 1, B*N1, K0, K1, 1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1, 1, 1>,S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, K1, S<0, 1, 2, 3, 4, 5>, 5, N1>; < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, 1, B*N1, K0, K1, 1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1, 1, 1>,S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, N1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
......
...@@ -7,20 +7,114 @@ add_custom_target(examples) ...@@ -7,20 +7,114 @@ add_custom_target(examples)
function(add_example_executable EXAMPLE_NAME FILE_NAME) function(add_example_executable EXAMPLE_NAME FILE_NAME)
message("adding example ${EXAMPLE_NAME}") message("adding example ${EXAMPLE_NAME}")
add_executable(${EXAMPLE_NAME} ${FILE_NAME}) set(result 1)
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) if(DEFINED DTYPES)
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN}) foreach(source IN LISTS FILE_NAME)
add_dependencies(examples ${EXAMPLE_NAME}) set(test 0)
add_dependencies(check ${EXAMPLE_NAME}) foreach(type IN LISTS DTYPES)
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1)
message("removing example source file ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(FILE_NAME)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
add_dependencies(examples ${EXAMPLE_NAME})
add_dependencies(check ${EXAMPLE_NAME})
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
set(result 0)
endif()
#message("add_example returns ${result}")
set(result ${result} PARENT_SCOPE)
endfunction(add_example_executable EXAMPLE_NAME) endfunction(add_example_executable EXAMPLE_NAME)
function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
message("adding example ${EXAMPLE_NAME}") message("adding example ${EXAMPLE_NAME}")
add_executable(${EXAMPLE_NAME} ${FILE_NAME}) set(result 1)
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) if(DEFINED DTYPES)
add_dependencies(examples ${EXAMPLE_NAME}) foreach(source IN LISTS FILE_NAME)
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1)
message("removing example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(FILE_NAME)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
add_dependencies(examples ${EXAMPLE_NAME})
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
set(result 0)
endif()
#message("add_example returns ${result}")
set(result ${result} PARENT_SCOPE)
endfunction(add_example_executable_no_testing EXAMPLE_NAME) endfunction(add_example_executable_no_testing EXAMPLE_NAME)
# add all example subdir # add all example subdir
......
...@@ -34,6 +34,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, ...@@ -34,6 +34,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
#endif #endif
// warm up // warm up
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...); kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
const int nrepeat = 10; const int nrepeat = 10;
#if DEBUG_LOG #if DEBUG_LOG
...@@ -50,6 +51,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, ...@@ -50,6 +51,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
for(int i = 0; i < nrepeat; ++i) for(int i = 0; i < nrepeat; ++i)
{ {
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...); kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
} }
hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
...@@ -64,11 +66,13 @@ float launch_and_time_kernel(const StreamConfig& stream_config, ...@@ -64,11 +66,13 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
else else
{ {
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...); kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
return 0; return 0;
} }
#else #else
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...); kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
return 0; return 0;
#endif #endif
...@@ -101,6 +105,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -101,6 +105,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
// warm up // warm up
preprocess(); preprocess();
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...); kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
const int nrepeat = 10; const int nrepeat = 10;
#if DEBUG_LOG #if DEBUG_LOG
...@@ -118,6 +123,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -118,6 +123,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
{ {
preprocess(); preprocess();
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...); kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
} }
hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
...@@ -133,11 +139,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -133,11 +139,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
{ {
preprocess(); preprocess();
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...); kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
return 0; return 0;
} }
#else #else
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...); kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
return 0; return 0;
#endif #endif
......
...@@ -28,7 +28,8 @@ MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&) ...@@ -28,7 +28,8 @@ MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
} }
template <index_t BlockSize, template <index_t BlockSize,
typename FloatAB, typename FloatA,
typename FloatB,
typename FloatAcc, typename FloatAcc,
typename AK0MK1BlockDesc, typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc, typename BK0NK1BlockDesc,
...@@ -58,7 +59,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -58,7 +59,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{}; static constexpr auto xdlops_gemm = XdlopsGemm<FloatA, MPerXDL, NPerXDL, KPack, FloatB>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
...@@ -294,9 +295,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -294,9 +295,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const BBlockBuffer& b_block_buf, const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const CThreadBuffer& c_thread_buf) const
{ {
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>( auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
a_thread_desc_.GetElementSpaceSize()); a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>( auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
b_thread_desc_.GetElementSpaceSize()); b_thread_desc_.GetElementSpaceSize());
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
...@@ -318,25 +319,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -318,25 +319,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
b_thread_buf); b_thread_buf);
static_for<0, KPerThread, KPack>{}([&](auto k) { static_for<0, KPerThread, KPack>{}([&](auto k) {
vector_type<FloatAB, KPack> a_thread_vec; vector_type<FloatA, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec; vector_type<FloatB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) { static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf a_thread_vec.template AsType<FloatA>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}]; [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf b_thread_vec.template AsType<FloatB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}]; [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
}); });
using mfma_input_type = using mfma_input_type_a =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type; typename vector_type<FloatA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type_b =
typename vector_type<FloatB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run( xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
}); });
...@@ -356,8 +359,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -356,8 +359,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops())); make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB, using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatAB, FloatA,
decltype(a_block_desc_m0_m1_m2_k), decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_), decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerThread>, Sequence<1, 1, 1, KPerThread>,
...@@ -366,8 +369,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -366,8 +369,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
A_K1, A_K1,
A_K1>; A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB, using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatAB, FloatB,
decltype(b_block_desc_n0_n1_n2_k), decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_), decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerThread>, Sequence<1, 1, 1, KPerThread>,
...@@ -385,7 +388,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -385,7 +388,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// the latest ROCm release. For unsupported compilers, inter-wave loop scheduler falls back to the // the latest ROCm release. For unsupported compilers, inter-wave loop scheduler falls back to the
// default loop scheduler which is given by the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0 // default loop scheduler which is given by the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0
template <index_t BlockSize, template <index_t BlockSize,
typename FloatAB, typename FloatA,
typename FloatB,
typename FloatAcc, typename FloatAcc,
typename AK0MK1BlockDesc, typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc, typename BK0NK1BlockDesc,
...@@ -397,7 +401,8 @@ template <index_t BlockSize, ...@@ -397,7 +401,8 @@ template <index_t BlockSize,
index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS> index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>
struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
: public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB, FloatA,
FloatB,
FloatAcc, FloatAcc,
AK0MK1BlockDesc, AK0MK1BlockDesc,
BK0NK1BlockDesc, BK0NK1BlockDesc,
...@@ -408,7 +413,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -408,7 +413,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
KPack> KPack>
{ {
using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB, FloatA,
FloatB,
FloatAcc, FloatAcc,
AK0MK1BlockDesc, AK0MK1BlockDesc,
BK0NK1BlockDesc, BK0NK1BlockDesc,
...@@ -440,9 +446,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -440,9 +446,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const BBlockBuffer& b_block_buf, const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const CThreadBuffer& c_thread_buf) const
{ {
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>( auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
a_thread_desc_.GetElementSpaceSize()); a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>( auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
b_thread_desc_.GetElementSpaceSize()); b_thread_desc_.GetElementSpaceSize());
static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) { static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) {
...@@ -479,20 +485,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -479,20 +485,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec; vector_type<FloatA, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec; vector_type<FloatB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) { static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_vec.template AsType<FloatA>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, 0, 0, k_ + i))>{}]; make_tuple(m0, 0, 0, k_ + i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset( b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, 0, 0, k_ + i))>{}]; make_tuple(n0, 0, 0, k_ + i))>{}];
}); });
using mfma_input_type = using mfma_input_type_a =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type; typename vector_type<FloatA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type_b =
typename vector_type<FloatB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
...@@ -514,8 +522,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -514,8 +522,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// TODO: insert setprio in more precise manner since we // TODO: insert setprio in more precise manner since we
// could have more than >1 MFMA instructions in single call // could have more than >1 MFMA instructions in single call
xdlops_gemm.template Run( xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
{ {
...@@ -541,8 +549,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -541,8 +549,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{})); make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB, using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatAB, FloatA,
decltype(a_block_desc_m0_m1_m2_k), decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_), decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerInnerLoop>, Sequence<1, 1, 1, KPerInnerLoop>,
...@@ -551,8 +559,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -551,8 +559,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
A_K1, A_K1,
A_K1>; A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB, using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatAB, FloatB,
decltype(b_block_desc_n0_n1_n2_k), decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_), decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerInnerLoop>, Sequence<1, 1, 1, KPerInnerLoop>,
...@@ -568,7 +576,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -568,7 +576,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
}; };
template <index_t BlockSize, template <index_t BlockSize,
typename FloatAB, typename FloatA,
typename FloatB,
typename FloatAcc, typename FloatAcc,
typename AK0MK1BlockDesc, typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc, typename BK0NK1BlockDesc,
...@@ -583,7 +592,8 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() ...@@ -583,7 +592,8 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
if constexpr(LoopSched == LoopScheduler::Default) if constexpr(LoopSched == LoopScheduler::Default)
{ {
return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB, FloatA,
FloatB,
FloatAcc, FloatAcc,
AK0MK1BlockDesc, AK0MK1BlockDesc,
BK0NK1BlockDesc, BK0NK1BlockDesc,
...@@ -596,7 +606,8 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() ...@@ -596,7 +606,8 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
else if constexpr(LoopSched == LoopScheduler::Interwave) else if constexpr(LoopSched == LoopScheduler::Interwave)
{ {
return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB, FloatA,
FloatB,
FloatAcc, FloatAcc,
AK0MK1BlockDesc, AK0MK1BlockDesc,
BK0NK1BlockDesc, BK0NK1BlockDesc,
...@@ -618,26 +629,27 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() ...@@ -618,26 +629,27 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
* 3. configurable k index starting position and step size after each FMA/XDL instruction * 3. configurable k index starting position and step size after each FMA/XDL instruction
*/ */
template <index_t BlockSize, template <
typename FloatAB, index_t BlockSize,
typename FloatAcc, typename FloatAB,
typename ATileDesc, typename FloatAcc,
typename BTileDesc, typename ATileDesc,
typename AMmaTileDesc, typename BTileDesc,
typename BMmaTileDesc, typename AMmaTileDesc,
index_t MPerBlock, typename BMmaTileDesc,
index_t NPerBlock, index_t MPerBlock,
index_t KPerBlock, index_t NPerBlock,
index_t MPerXDL, index_t KPerBlock,
index_t NPerXDL, index_t MPerXDL,
index_t MRepeat, index_t NPerXDL,
index_t NRepeat, index_t MRepeat,
index_t KPack, index_t NRepeat,
bool TransposeC = false, index_t KPack,
index_t AMmaKStride = bool TransposeC = false,
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{}.K0PerXdlops, index_t AMmaKStride =
index_t BMmaKStride = KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{}.K0PerXdlops> index_t BMmaKStride =
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops>
struct BlockwiseGemmXdlops_v2 struct BlockwiseGemmXdlops_v2
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -654,7 +666,8 @@ struct BlockwiseGemmXdlops_v2 ...@@ -654,7 +666,8 @@ struct BlockwiseGemmXdlops_v2
static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{}; static constexpr auto xdlops_gemm =
XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp"
#include "ck/utility/is_detected.hpp"
namespace ck {
// Thread-group level multi-source, multi-destination tensor slice data movement
// Assume:
// 1. All sources and destinations are DynamicBuffer
// 2. Same VectorDim and ScalerPerVector for all sources and destinations
// 3. DstInMemOps are per destination tensor
// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
//
// Does following things to avoid scratch memory issue
// 1. Pass tensor descritpors by reference (or tuple of references)
// 2. Does not keep reference to tensor descriptor
// 3. Does not construct new tensor coordinate when call Run()
template <typename ThreadGroup,
typename SrcDatas,
typename DstDatas,
typename SrcDescs,
typename DstDescs,
typename ElementwiseOperation,
typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
typename SliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
typename ThreadTransferDstResetCoordinateAfterRunFlags>
struct ThreadGroupTensorSliceTransfer_v7r2
{
static constexpr index_t nDim =
remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();
static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();
using Index = MultiIndex<nDim>;
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
__device__ constexpr ThreadGroupTensorSliceTransfer_v7r2(
const SrcDescs& src_descs,
const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
const DstDescs& dst_descs,
const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src_descs,
StaticallyIndexedArray<Index, nSrc>{},
dst_descs,
StaticallyIndexedArray<Index, nDst>{},
element_op)
{
static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
nDst == DstDatas::Size() && nDst == DstDescs::Size() &&
nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(),
"wrong!");
static_for<0, nSrc, 1>{}([&](auto i) {
static_assert(
nDim == remove_cvref_t<tuple_element_t<i.value, SrcDescs>>::GetNumOfDimension(),
"wrong!");
});
static_for<0, nDst, 1>{}([&](auto i) {
static_assert(
nDim == remove_cvref_t<tuple_element_t<i.value, DstDescs>>::GetNumOfDimension(),
"wrong!");
});
static_assert(nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! ThreadGroup::GetNumOfThread() too small");
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
const auto src_thread_slice_origins = generate_tuple(
[&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; },
Number<nSrc>{});
const auto dst_thread_slice_origins = generate_tuple(
[&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; },
Number<nDst>{});
threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
}
}
template <typename SrcBuffers>
__device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_descs, src_bufs);
}
}
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
template <typename DstBuffers>
__device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
threadwise_transfer_.RunWrite(dst_descs, dst_bufs);
else
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs));
}
}
template <typename SrcBuffers, typename DstBuffers>
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs)
{
RunRead(src_descs, src_bufs);
RunWrite(dst_descs, dst_bufs);
}
template <index_t ISrc>
__device__ void
MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
}
}
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step)
{
static_for<0, SrcDescs::Size(), 1>{}(
[&](auto i) { MoveSrcSliceWindow(src_descs, i, step); });
}
template <index_t IDst>
__device__ void
MoveDstSliceWindow(const DstDescs& dst_descs, Number<IDst> iDst, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
}
}
__device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step)
{
static_for<0, DstDescs::Size(), 1>{}(
[&](auto i) { MoveDstSliceWindow(dst_descs, i, step); });
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v7r2<SrcDatas,
DstDatas,
SrcDescs,
DstDescs,
ElementwiseOperation,
DstInMemOps,
decltype(thread_slice_lengths),
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
DstScalarPerVector,
ThreadTransferSrcResetCoordinateAfterRunFlags,
ThreadTransferDstResetCoordinateAfterRunFlags>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace conv_tensor_rearrange_op {
struct BaseConvTensorRearrangeOp
{
};
struct ImageToColumn : public BaseConvTensorRearrangeOp
{
static constexpr const char* name = "Image to Column";
};
struct ColumnToImage : public BaseConvTensorRearrangeOp
{
static constexpr const char* name = "Column to Image";
};
template <typename Op,
typename std::enable_if<std::is_base_of<BaseConvTensorRearrangeOp, Op>::value,
bool>::type = false>
std::ostream& operator<<(std::ostream& os, const BaseConvTensorRearrangeOp&)
{
os << Op::name;
return os;
}
} // namespace conv_tensor_rearrange_op
} // namespace ck
...@@ -12,21 +12,26 @@ namespace tensor_operation { ...@@ -12,21 +12,26 @@ namespace tensor_operation {
namespace device { namespace device {
/** /**
* \brief Image to column. * \brief Convolution Tensor Rearrange.
* *
* This Device operator converts image ([G, N, Di, Hi, Wi, C]) to the gemm * This Device operator supports conversion image ([G, N, Di, Hi, Wi, C]) to
* problem([N * Do * Ho * Wo, Z * Y * X * C]). G must be equal to 1. * the gemm problem([N * Do * Ho * Wo, Z * Y * X * C]) (Image to Column) and
* conversion gemm form to the image (Column to Image).
*
* Note that G must be equal to 1.
* *
* \tparam NDimSpatial Number of spatial dimensions. * \tparam NDimSpatial Number of spatial dimensions.
* \tparam InputLayout Input Layout. * \tparam ImageLayout Input Layout.
* \tparam InputDataType Input Data Type. * \tparam InputDataType Input Data Type.
* \tparam OutputDataType Output Data Type. * \tparam OutputDataType Output Data Type.
* \tparam ConvTensorRearrangeOp Operation type: ImageToColumn, ColumnToImage.
*/ */
template <index_t NDimSpatial, template <index_t NDimSpatial,
typename InputLayout, typename ImageLayout,
typename InputDataType, typename InputDataType,
typename OutputDataType> typename OutputDataType,
struct DeviceImageToColumn : public BaseOperator typename ConvTensorRearrangeOp>
struct DeviceConvTensorRearrange : public BaseOperator
{ {
/** /**
...@@ -39,8 +44,8 @@ struct DeviceImageToColumn : public BaseOperator ...@@ -39,8 +44,8 @@ struct DeviceImageToColumn : public BaseOperator
* \param input_spatial_lengths Input spatial lengths. * \param input_spatial_lengths Input spatial lengths.
* \param filter_spatial_lengths Filter spatial lengths. * \param filter_spatial_lengths Filter spatial lengths.
* \param output_spatial_lengths Output spatial lengths. * \param output_spatial_lengths Output spatial lengths.
* \param input_g_n_c_wis_strides Input strides in order [G, N, C, D, H, W]. * \param image_g_n_c_wis_strides Image strides in order [G, N, C, D, H, W].
* \param output_m_k_strides Output strides. * \param gemm_m_k_strides Gemm form strides.
* \param conv_filter_strides Convolution filter strides. * \param conv_filter_strides Convolution filter strides.
* \param conv_filter_dilations Convolution filter dilations. * \param conv_filter_dilations Convolution filter dilations.
* \param input_left_pads Convolution left pads. * \param input_left_pads Convolution left pads.
...@@ -55,8 +60,8 @@ struct DeviceImageToColumn : public BaseOperator ...@@ -55,8 +60,8 @@ struct DeviceImageToColumn : public BaseOperator
const std::array<index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
const std::array<index_t, 2>& output_m_k_strides, const std::array<index_t, 2>& gemm_m_k_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// GEMM:
// input : A0[M, K], B0[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleABD : public BaseOperator
{
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
std::array<const void*, NumBTensor> p_bs,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
ck::index_t M,
ck::index_t N,
ck::index_t K,
std::array<ck::index_t, NumATensor> StrideAs,
std::array<ck::index_t, NumBTensor> StrideBs,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -29,7 +29,9 @@ template <ck::index_t NDimSpatial, ...@@ -29,7 +29,9 @@ template <ck::index_t NDimSpatial,
typename EDataType, typename EDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation> typename CDEElementwiseOperation,
typename AComputeType = ADataType,
typename BComputeType = AComputeType>
struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator
{ {
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
......
...@@ -20,7 +20,9 @@ template <ck::index_t NDimSpatial, ...@@ -20,7 +20,9 @@ template <ck::index_t NDimSpatial,
typename OutDataType, typename OutDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation> typename OutElementwiseOperation,
typename ComputeTypeA = InDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGroupedConvBwdWeight : public BaseOperator struct DeviceGroupedConvBwdWeight : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
......
...@@ -29,7 +29,8 @@ template <index_t NDimSpatial, ...@@ -29,7 +29,8 @@ template <index_t NDimSpatial,
typename EDataType, typename EDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation> typename CDEElementwiseOperation,
typename ComputeType = ADataType>
struct DeviceGroupedConvFwdMultipleD : public BaseOperator struct DeviceGroupedConvFwdMultipleD : public BaseOperator
{ {
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
......
...@@ -8,6 +8,57 @@ namespace ck { ...@@ -8,6 +8,57 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
///
/// @brief Structure representing single GEMM problem arguments.
///
/// The pointer to the vector of those structures is passed
/// to the GroupedGEMM entry point kernel.
///
struct GroupedGemmKernelArguments
{
__host__ __device__ GroupedGemmKernelArguments(const void* p_a_grid_,
const void* p_b_grid_,
void* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_}
{
}
const void* p_a_grid;
const void* p_b_grid;
void* p_c_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
void Print() const
{
std::cout << "arg {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << "}" << std::endl;
}
};
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename DsLayout, typename DsLayout,
...@@ -31,7 +82,28 @@ struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm<ALayout, ...@@ -31,7 +82,28 @@ struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm<ALayout,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation> CElementwiseOperation>
{ {
virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0; //----------------------------------------------------------------------------------------------
/// @brief Sets the k batch size.
///
/// @param p_arg Pointer to the Argument we're going to change.
/// @param[in] kbatch The kbatch value.
///
virtual void SetKBatchSize([[maybe_unused]] BaseArgument* p_arg,
[[maybe_unused]] index_t kbatch) const
{
}
//----------------------------------------------------------------------------------------------
/// @brief Sets the device kernel arguments pointer.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
/// arguments.
///
virtual void SetDeviceKernelArgs([[maybe_unused]] BaseArgument* p_arg,
[[maybe_unused]] const void* p_dev_kernel_args) const
{
}
}; };
} // namespace device } // namespace device
......
...@@ -22,22 +22,22 @@ template <typename InDataType, ...@@ -22,22 +22,22 @@ template <typename InDataType,
index_t NumReduceDim> index_t NumReduceDim>
struct DeviceSoftmax : public BaseOperator struct DeviceSoftmax : public BaseOperator
{ {
// ///
// @brief Makes a pointer to Argument class. /// @brief Makes a pointer to Argument class.
// ///
// @param[in] inLengths Input tensor extent(s) from high to low dimension /// @param[in] inLengths Input tensor extent(s) from high to low dimension
// @param[in] inStrides Input tensor stride(s) from high to low dimension /// @param[in] inStrides Input tensor stride(s) from high to low dimension
// @param[in] reduceDims The dimension(s) the normalization operation is applied /// @param[in] reduceDims The dimension(s) the normalization operation is applied
// @param[in] alpha double type value /// @param[in] alpha double type value
// @param[in] beta double type value /// @param[in] beta double type value
// @param[in] in_dev Typeless const pointer in device memory storing the input /// @param[in] in_dev Typeless const pointer in device memory storing the input
// tensor /// tensor
// @param out_dev Typeless pointer in device memory storing the output tensor /// @param out_dev Typeless pointer in device memory storing the output tensor
// @param[in] in_elementwise_op The input elementwise operation. /// @param[in] in_elementwise_op The input elementwise operation.
// @param[in] acc_elementwise_op The accumulation elementwise operation. /// @param[in] acc_elementwise_op The accumulation elementwise operation.
// ///
// @return Unique pointer to the Argument class. /// @return Unique pointer to the Argument class.
// ///
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> inLengths, MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides, const std::vector<index_t> inStrides,
......
...@@ -66,7 +66,8 @@ template <typename ALayout, ...@@ -66,7 +66,8 @@ template <typename ALayout,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(), LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1, PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeType = CDataType> typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
BLayout, BLayout,
CLayout, CLayout,
...@@ -131,7 +132,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -131,7 +132,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
PipelineVer, PipelineVer,
ComputeType>; ComputeTypeA,
ComputeTypeB>;
using Argument = typename GridwiseGemm::Argument; using Argument = typename GridwiseGemm::Argument;
......
...@@ -168,7 +168,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -168,7 +168,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
stream_config.stream_id_)); stream_config.stream_id_));
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg, b2c_map); stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
}; };
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
......
...@@ -303,7 +303,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout, ...@@ -303,7 +303,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
StrideC, StrideC,
GridwiseGemv::CalculateMPadded(M), GridwiseGemv::CalculateMPadded(M),
GridwiseGemv::CalculateNPadded(N), GridwiseGemv::CalculateNPadded(N),
GridwiseGemv::CalculateKPadded(K, KBatch), K,
GridwiseGemv::CalculateK0(K, KBatch), GridwiseGemv::CalculateK0(K, KBatch),
KBatch}; // // KBatch}; // //
} }
...@@ -336,7 +336,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout, ...@@ -336,7 +336,7 @@ struct deviceGemvDl : public DeviceGemv<ALayout,
StrideC, StrideC,
GridwiseGemv::CalculateMPadded(M), GridwiseGemv::CalculateMPadded(M),
GridwiseGemv::CalculateNPadded(N), GridwiseGemv::CalculateNPadded(N),
GridwiseGemv::CalculateKPadded(K, KBatch), K,
GridwiseGemv::CalculateK0(K, KBatch), GridwiseGemv::CalculateK0(K, KBatch),
KBatch); // // KBatch); // //
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment