"...composable_kernel_rocm.git" did not exist on "fb54b55b2a5a611c8f50831a13df08be2e10aa59"
Commit ceeca94a authored by rocking

Merge commit 'd6709dc3' into gemm_layernorm_welford

parents 2732d06c d6709dc3
@@ -352,6 +352,8 @@ def runCKProfiler(Map conf=[:]){
archiveArtifacts "perf_conv_bwd_data_${gpu_arch}.log"
archiveArtifacts "perf_gemm_bilinear_${gpu_arch}.log"
archiveArtifacts "perf_reduction_${gpu_arch}.log"
archiveArtifacts "perf_splitK_gemm_${gpu_arch}.log"
archiveArtifacts "perf_onnx_gemm_${gpu_arch}.log"
// stash perf files to master
stash name: "perf_gemm_${gpu_arch}.log"
stash name: "perf_resnet50_N256_${gpu_arch}.log"
@@ -362,6 +364,8 @@ def runCKProfiler(Map conf=[:]){
stash name: "perf_conv_bwd_data_${gpu_arch}.log"
stash name: "perf_gemm_bilinear_${gpu_arch}.log"
stash name: "perf_reduction_${gpu_arch}.log"
stash name: "perf_splitK_gemm_${gpu_arch}.log"
stash name: "perf_onnx_gemm_${gpu_arch}.log"
//we will process results on the master node
}
else{
@@ -442,6 +446,8 @@ def process_results(Map conf=[:]){
unstash "perf_conv_bwd_data_${gpu_arch}.log"
unstash "perf_gemm_bilinear_${gpu_arch}.log"
unstash "perf_reduction_${gpu_arch}.log"
unstash "perf_splitK_gemm_${gpu_arch}.log"
unstash "perf_onnx_gemm_${gpu_arch}.log"
sh "./process_qa_data.sh ${gpu_arch}"
}
else{
...
add_executable(client_softmax4d softmax4d.cpp)
target_link_libraries(client_softmax4d PRIVATE composable_kernel::device_operations)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <functional>
#include <numeric>
#include <iomanip>
#include <iostream>
#include <limits>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/softmax.hpp"
using InDataType = ck::half_t;
using OutDataType = ck::half_t;
using AccDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 4;
constexpr int NumReduceDim = 2;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
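Note that the example as committed profiles the kernels on whatever happens to be in the freshly allocated device buffer. If defined input values are wanted, a host buffer can be filled and copied in before the timing loop. The fragment below is a hypothetical addition (not part of this commit); it assumes it is placed inside main() after the SimpleDeviceMem allocations, and reuses the example's own InDataType, num_elements and in objects.

// Hypothetical addition (not in this commit): give the input tensor defined values
// before profiling. Relies on InDataType, num_elements and "in" defined in main() below.
// std::vector<InDataType> host_in(num_elements, static_cast<InDataType>(1.0f));
// (void)hipMemcpy(in.GetDeviceBuffer(),
//                 host_in.data(),
//                 sizeof(InDataType) * num_elements,
//                 hipMemcpyHostToDevice);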
int main(int argc, char* argv[])
{
std::vector<ck::index_t> in_lengths{2, 8, 128, 1024};
std::vector<ck::index_t> in_strides{8 * 128 * 1024, 128 * 1024, 1024, 1};
std::vector<ck::index_t> reduce_dims{2, 3};
ck::index_t num_elements =
std::accumulate(in_lengths.begin(), in_lengths.end(), 1, std::multiplies<ck::index_t>());
AccDataType alpha{2.0f};
AccDataType beta{2.0f};
SimpleDeviceMem in(sizeof(InDataType) * num_elements);
SimpleDeviceMem out(sizeof(OutDataType) * num_elements);
using DeviceOp = ck::tensor_operation::device::
DeviceSoftmax<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
if(op_ptr->GetRank() != Rank || op_ptr->GetNumReduceDim() != NumReduceDim)
{
continue;
}
auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths,
in_strides,
reduce_dims,
&alpha,
&beta,
in.GetDeviceBuffer(),
out.GetDeviceBuffer(),
PassThrough{},
PassThrough{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
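// Estimated bytes moved: the input is read once; the output is written once and read
// once more when beta != 0, hence the factor of 1 or 2 below.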
std::size_t num_bytes = num_elements * sizeof(InDataType) +
(beta == 0.0f ? 1 : 2) * num_elements * sizeof(OutDataType);
float gb_per_sec = num_bytes / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
<< op_name << std::endl;
if(ave_time < best_ave_time)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl;
// run the best instance
if(found)
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths,
in_strides,
reduce_dims,
&alpha,
&beta,
in.GetDeviceBuffer(),
out.GetDeviceBuffer(),
PassThrough{},
PassThrough{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}
\ No newline at end of file
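For reference, the alpha/beta arguments above follow the usual blending convention: the softmax is taken along the reduce dimensions and combined with the prior output. The host-side sketch below only illustrates that interpretation for a single reduced row; it is not the CK reference implementation, and the blending rule (y = alpha * softmax(x) + beta * y) is stated here as an assumption.

// Sketch: y = alpha * softmax(x) + beta * y over one reduced row (illustration only).
#include <algorithm>
#include <cmath>
#include <vector>

void softmax_row(const std::vector<float>& x, std::vector<float>& y, float alpha, float beta)
{
    // subtract the row max for numerical stability before exponentiating
    const float max_x = *std::max_element(x.begin(), x.end());
    std::vector<float> e(x.size());
    float sum = 0.f;
    for(std::size_t i = 0; i < x.size(); ++i)
    {
        e[i] = std::exp(x[i] - max_x);
        sum += e[i];
    }
    for(std::size_t i = 0; i < x.size(); ++i)
        y[i] = alpha * (e[i] / sum) + beta * y[i];
}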
@@ -11,3 +11,4 @@ add_subdirectory(02_gemm_add_add_fastgelu)
add_subdirectory(03_gemm_layernorm)
add_subdirectory(04_contraction)
add_subdirectory(05_layernorm)
add_subdirectory(06_softmax)
@@ -9,37 +9,41 @@
#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
using namespace ck;
using namespace ck::tensor_operation::device;
using InDataType = ck::half_t;
using OutDataType = ck::half_t;
using AccDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 3;
constexpr int NumReduceDim = 1;
using DeviceInstance = DeviceSoftmax<InDataType,
                                     AccDataType,
                                     OutDataType,
                                     Rank,
                                     NumReduceDim,
                                     256, // BlockSize
                                     8,   // ClusterM
                                     32,  // ClusterK
                                     1,   // SliceM
                                     8,   // SliceK
                                     1,   // SrcVecDim (0=M, 1=K)
                                     8,   // SrcScalarPerVector
                                     8>;  // OutScalarPerVector
using DeviceInstance = DeviceSoftmaxImpl<InDataType,
                                         AccDataType,
                                         OutDataType,
                                         PassThrough, // InElementwiseOperation
                                         PassThrough, // AccElementwiseOperation
                                         Rank,
                                         NumReduceDim,
                                         256, // BlockSize
                                         8,   // ClusterM
                                         32,  // ClusterK
                                         1,   // SliceM
                                         8,   // SliceK
                                         1,   // SrcVecDim (0=M, 1=K)
                                         8,   // SrcScalarPerVector
                                         8>;  // OutScalarPerVector
static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'},
                                       {"verify", required_argument, nullptr, 'v'},
@@ -196,7 +200,7 @@ int main(int argc, char* argv[])
if(args.do_verification)
{
    using ReferenceInstance =
        tensor_operation::host::ReferenceSoftmax<InDataType, OutDataType, AccDataType>;
        ck::tensor_operation::host::ReferenceSoftmax<InDataType, OutDataType, AccDataType>;
    ReferenceInstance ref;
    auto ref_arg = ref.MakeArgument(in, out_ref, alpha, beta, reduceDims);
    auto invoker = ref.MakeInvoker();
@@ -220,7 +224,9 @@ int main(int argc, char* argv[])
    &alpha,
    &beta,
    in_dev.GetDeviceBuffer(),
    out_dev.GetDeviceBuffer());
    out_dev.GetDeviceBuffer(),
    PassThrough{},
    PassThrough{});
if(!device_instance.IsSupportedArgument(argument_ptr.get()))
{
...
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_padded_batched_gemm_scale_softmax_gemm_xdl_fp16 padded_batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
add_custom_target(example_batched_gemm_scale_softmax_gemm)
add_dependencies(example_batched_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
add_dependencies(example_batched_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_batched_gemm_scale_softmax_gemm example_padded_batched_gemm_scale_softmax_gemm_xdl_fp16)
@@ -58,7 +58,7 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNOPadding;
using DeviceGemmInstance =
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
@@ -77,7 +77,7 @@ using DeviceGemmInstance =
        Acc0ElementOp,
        B1ElementOp,
        CElementOp,
        GemmDefault,
        GemmSpec,
        1,
        256,
        128, // MPerBlock
@@ -166,8 +166,6 @@ int main(int argc, char* argv[])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t G0 = 7;
ck::index_t G1 = 13;
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
if(argc == 1)
{
@@ -204,6 +202,9 @@
exit(0);
}
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
...
@@ -49,14 +49,9 @@ using B0Layout = Col;
using B1Layout = Row;
using CLayout = Row;
// When using padded DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle kernel, 2 specs should be set:
// 1. GemmSpecialization should be set to MNPadding(or NPadding in future)
// 2. Acc0ElementOp should be set to ScaleAndResetNaNToMinusInfinity
// Otherwise, wrong result may be produced.
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::ScaleAndResetNaNToMinusInfinity;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
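The comment removed above explained why the padded kernel previously relied on ScaleAndResetNaNToMinusInfinity: values originating from the padded region must not survive the softmax. A rough scalar picture of that element-wise operator is sketched below; it is an illustration of the idea, not the library's implementation.

// Sketch of the idea behind ScaleAndResetNaNToMinusInfinity: entries that show up as NaN
// are forced to -inf so that exp(-inf) == 0 and they drop out of the softmax sum.
#include <cmath>
#include <limits>

inline float scale_and_reset_nan_to_minus_inf(float scale, float x)
{
    return std::isnan(x) ? -std::numeric_limits<float>::infinity() : scale * x;
}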
...
@@ -144,6 +144,17 @@
// workaround: compiler generating inefficient ds_write instructions
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
// (gfx908 only) workaround: compiler crash in fused kernels on mainline #9110; #10738 seems ok
// error message was "fatal error: error in backend: Error while trying to spill VGPR0 from class
// VGPR_32: Cannot scavenge register without an emergency spill slot!"
// this falls back to a less ideal way of handling NPadding in the fused attention kernel
#ifdef __gfx908__
#define CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER 1
#else
// for __gfx90a__, ...
#define CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER 0
#endif // __gfx908__
// workaround: verification failure, due to compiler regression, for conv bwd-data fp16 using some
// tuning parameter
#define CK_WORKAROUND_SWDEV_325164 0
...
@@ -16,7 +16,8 @@ template <index_t BlockSize,
          typename AccDataType,
          typename ThreadMap_M_K, // thread_id to m_k
          typename ThreadClusterDesc_M_K,
          typename ThreadSliceDesc_M_K>
          typename ThreadSliceDesc_M_K,
          bool IgnoreNaN = false>
struct BlockwiseSoftmax
{
    static constexpr auto I0 = Number<0>{};
@@ -27,11 +28,33 @@ struct BlockwiseSoftmax
    using ThreadSliceDesc_M = decltype(
        make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0))));
    using ThreadwiseMaxReduce = ThreadwiseReduction<AccDataType,
                                                    ThreadSliceDesc_M_K,
                                                    ThreadSliceDesc_M,
                                                    reduce::Max,
                                                    false>;
    using ThreadwiseMaxReduce = typename conditional<
        IgnoreNaN,
        ThreadwiseReduction<AccDataType,
                            ThreadSliceDesc_M_K,
                            ThreadSliceDesc_M,
                            reduce::Max,
                            false,
                            detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>,
        ThreadwiseReduction<AccDataType,
                            ThreadSliceDesc_M_K,
                            ThreadSliceDesc_M,
                            reduce::Max,
                            false>>::type;
    using ThreadwiseSumReduce = typename conditional<
        IgnoreNaN,
        ThreadwiseReduction<AccDataType,
                            ThreadSliceDesc_M_K,
                            ThreadSliceDesc_M,
                            reduce::Add,
                            false,
                            detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>,
        ThreadwiseReduction<AccDataType,
                            ThreadSliceDesc_M_K,
                            ThreadSliceDesc_M,
                            reduce::Add,
                            false>>::type;
    using ThreadClusterLengths_M_K = decltype(ThreadClusterDesc_M_K{}.GetLengths());
@@ -49,12 +72,6 @@ struct BlockwiseSoftmax
                                                    reduce::Add,
                                                    false>;
    using ThreadwiseSumReduce = ThreadwiseReduction<AccDataType,
                                                    ThreadSliceDesc_M_K,
                                                    ThreadSliceDesc_M,
                                                    reduce::Add,
                                                    false>;
    using BufferType = StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MRepeat, true>;
    template <typename CThreadBuffer, typename WorkspaceBuffer>
@@ -74,7 +91,9 @@ struct BlockwiseSoftmax
        static_for<0, MRepeat, 1>{}([&](auto iM) {
            static_for<0, KRepeat, 1>{}([&](auto iK) {
                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
                in_thread_buf(offset) = math::exp(in_thread_buf[offset] - max_value_buf(iM));
                in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset])
                                            ? 0
                                            : math::exp(in_thread_buf[offset] - max_value_buf(iM));
            });
        });
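The new IgnoreNaN path swaps in an accumulator that skips NaN values during the max/sum reductions and maps NaN inputs to zero after the exp. A scalar sketch of that accumulation rule follows; it only illustrates the behavior and is not the ThreadwiseReduction implementation.

// Scalar illustration of "accumulate with NaN ignore": a NaN element leaves the
// running maximum or running sum untouched instead of poisoning it.
#include <algorithm>
#include <cmath>

inline void max_accumulate_ignoring_nan(float& acc, float x)
{
    if(!std::isnan(x))
        acc = std::max(acc, x);
}

inline void sum_accumulate_ignoring_nan(float& acc, float x)
{
    if(!std::isnan(x))
        acc += x;
}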
...
@@ -456,8 +456,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
                                       b_grid_desc_bk0_n_bk1_,
                                       b1_grid_desc_bk0_n_bk1_,
                                       c_grid_desc_m_n_,
                                       block_2_ctile_map_,
                                       raw_lengths_m_n_k_o_))
                                       block_2_ctile_map_))
        {
            c_grid_desc_mblock_mperblock_nblock_nperblock_ =
                GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -508,8 +507,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
                                           arg.b_grid_desc_bk0_n_bk1_,
                                           arg.b1_grid_desc_bk0_n_bk1_,
                                           arg.c_grid_desc_m_n_,
                                           arg.block_2_ctile_map_,
                                           arg.raw_lengths_m_n_k_o_))
                                           arg.block_2_ctile_map_))
        {
            throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
        }
@@ -628,8 +626,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
                                       arg.b_grid_desc_bk0_n_bk1_,
                                       arg.b1_grid_desc_bk0_n_bk1_,
                                       arg.c_grid_desc_m_n_,
                                       arg.block_2_ctile_map_,
                                       arg.raw_lengths_m_n_k_o_);
                                       arg.block_2_ctile_map_);
    }
    // polymorphic
...
@@ -194,6 +194,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
            MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
// FIXME: pad K
static_assert(!matrix_padder.PadK, "KPadding is currently not supported");
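The GemmGemmPadder constructed above decides, from GemmSpec, which of the M/N/K/Gemm1N extents get right-padded up to the corresponding per-block tile size. A quick illustration of just that rounding arithmetic (not the padder itself) is sketched here; it mirrors the integer_divide_ceil computation used by the descriptor construction in the old code below.

// Illustration only: how much right-padding rounds a raw extent up to a multiple
// of the block tile size, i.e. integer_divide_ceil(raw, per_block) * per_block - raw.
#include <cassert>

inline int padded_extent(int raw, int per_block)
{
    assert(per_block > 0);
    return (raw + per_block - 1) / per_block * per_block;
}

inline int pad_amount(int raw, int per_block) { return padded_extent(raw, per_block) - raw; }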
    static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
@@ -209,92 +212,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
            }
        }();
        const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
            const auto M = a_grid_desc_m_k.GetLength(I0);
            const auto K = a_grid_desc_m_k.GetLength(I1);
            const auto AK0 = K / AK1;

            const auto AK0 = K / AK1;
            const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
            return transform_tensor_descriptor(a_grid_desc_m_k,
                                               make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                                                          make_pass_through_transform(M)),
                                               make_tuple(Sequence<1>{}, Sequence<0>{}),
                                               make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 =
                transform_tensor_descriptor(a_grid_desc_m_k,
                                            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                                                       make_pass_through_transform(MRaw)),
                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
    }

    static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
@@ -312,84 +241,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
            }
        }();
        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
        const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);

            return b_grid_desc_bk0_n_bk1;
        }
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
const auto BK0 = KRaw / BK1;
        const auto BK0 = K / BK1;
            const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        return transform_tensor_descriptor(b_grid_desc_n_k,
                                           make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                                                      make_pass_through_transform(N)),
                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
    }
    // Args: Gemm1KRaw, Gemm1NRaw, StrideB1
@@ -408,47 +271,19 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
            }
        }();
        const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);
        const auto N = b1_grid_desc_n_k.GetLength(I0);
        const auto K = b1_grid_desc_n_k.GetLength(I1);
        const auto B1K0 = K / B1K1;

        const auto N = math::integer_divide_ceil(NRaw, Gemm1NPerBlock) * Gemm1NPerBlock;
        const auto K = math::integer_divide_ceil(KRaw, Gemm1KPerBlock) * Gemm1KPerBlock;
        const auto NPad = N - NRaw;
        const auto KPad = K - KRaw;
        // TODO: implement finer-grained padding
if constexpr(GemmSpec == GemmSpecialization::Default)
{
const auto B1K0 = KRaw / B1K1;
            return transform_tensor_descriptor(
                b1_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b1_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                           make_pass_through_transform(NRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b1_grid_desc_bk0_n_bk1;
}
else
{
// pad both B1N and B1K
const auto B1K0 = K / B1K1;
const auto b1_grid_desc_n_k =
transform_tensor_descriptor(b1_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b1_grid_desc_bk0_n_bk1;
}
    }

    // assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
@@ -662,7 +497,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        CShuffleNXdlPerWavePerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        LoopSched>;
        LoopSched,
matrix_padder.PadN>;
    // Argument
    // FIXME: constness
@@ -711,7 +547,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
              c_element_op_{c_element_op},
              batch_count_(Batch),
              compute_base_ptr_of_batch_{
                  BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_}
                  BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_},
raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw},
c_extent_lowest_{c_gs_ms_gemm1ns_lengths.back()},
c_stride_lowest_{c_gs_ms_gemm1ns_strides.back()}
        {
            if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                           b_grid_desc_bk0_n_bk1_,
@@ -745,6 +584,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        CElementwiseOperation c_element_op_;
        index_t batch_count_;
        ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
// For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_m_n_k_o_;
index_t c_extent_lowest_;
index_t c_stride_lowest_;
    };

    // Invoker
@@ -849,9 +693,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        }

        // Check if C permute dimension matches GEMM + GEMM shape
        const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0);
        const index_t c_m = arg.c_grid_desc_g_m_n_.GetLength(I1);
        const index_t c_gemm1n = arg.c_grid_desc_g_m_n_.GetLength(I2);
        const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
        const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
        const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1);
        const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
        const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
        if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
@@ -859,7 +703,35 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
            return false;
        }

        // TODO: Check A/B0/B1 length & stride and scalar per vector
        // Note: we need raw lengths since threadwise copy cannot handle vector load when part
        // of the vector is out of bounds
const auto MRaw = arg.raw_lengths_m_n_k_o_[0];
const auto NRaw = arg.raw_lengths_m_n_k_o_[1];
const auto KRaw = arg.raw_lengths_m_n_k_o_[2];
const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
// Check scalar per vector requirement
const auto a_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
const auto b_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
const auto b1_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
const auto c_extent_lowest = arg.c_extent_lowest_;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
// Check vector store requirement; assumes last dimension in N to be contiguous
if(arg.c_stride_lowest_ != 1)
{
return false;
}
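As a worked example of the divisibility rule enforced above (the values here are assumed, not from the commit): with CShuffleBlockTransferScalarPerVector_NPerBlock set to 8, an output row length of 120 passes (120 % 8 == 0), while 100 is rejected, since an 8-wide vector store starting at offset 96 would run past element 99 of the row.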
        return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                           arg.b_grid_desc_bk0_n_bk1_,
@@ -996,7 +868,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
            << MPerBlock << ", "
            << Gemm1NPerBlock << ", "
            << Gemm1KPerBlock << ", "
            << B1K1 << ">";
            << B1K1 << ", "
            << getGemmSpecializationString(GemmSpec) << ">";
        // clang-format on
        return str.str();
...
@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -198,6 +199,13 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
static constexpr auto matrix_padder =
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
// FIXME: pad K
static_assert(!matrix_padder.PadK, "KPadding is currently not supported");
    static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
@@ -213,92 +221,18 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
            }
        }();
        const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
            const auto M = a_grid_desc_m_k.GetLength(I0);
            const auto K = a_grid_desc_m_k.GetLength(I1);
            const auto AK0 = K / AK1;

            const auto a_grid_desc_m_k = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
            return transform_tensor_descriptor(a_grid_desc_m_k,
                                               make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                                                          make_pass_through_transform(M)),
                                               make_tuple(Sequence<1>{}, Sequence<0>{}),
                                               make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else
        {
            // not pad M or K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
} }
    static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
@@ -316,84 +250,18 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
            }
        }();
        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
        const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);

            return b_grid_desc_bk0_n_bk1;
        }
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
const auto BK0 = KRaw / BK1;
        const auto BK0 = K / BK1;
            const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        return transform_tensor_descriptor(b_grid_desc_n_k,
                                           make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                                                      make_pass_through_transform(N)),
                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
    }
    // Args: Gemm1KRaw, Gemm1NRaw, StrideB1
@@ -412,47 +280,19 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
            }
        }();
        const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);
        const auto N = math::integer_divide_ceil(NRaw, Gemm1NPerBlock) * Gemm1NPerBlock;
const auto K = math::integer_divide_ceil(KRaw, Gemm1KPerBlock) * Gemm1KPerBlock;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
        const auto N = b1_grid_desc_n_k.GetLength(I0);
        const auto K = b1_grid_desc_n_k.GetLength(I1);

        // TODO: implement finer-grained padding
        if constexpr(GemmSpec == GemmSpecialization::Default)
{
const auto B1K0 = KRaw / B1K1;
        const auto B1K0 = K / B1K1;
            const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b1_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        return transform_tensor_descriptor(
            b1_grid_desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                       make_pass_through_transform(N)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b1_grid_desc_bk0_n_bk1;
        }
        else
        {
            // pad both B1N and B1K
            const auto B1K0 = K / B1K1;
const auto b1_grid_desc_n_k =
transform_tensor_descriptor(b1_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b1_grid_desc_bk0_n_bk1;
}
} }
    static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
@@ -470,47 +310,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
            }
        }();
        return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, Gemm1NPerBlock) * Gemm1NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
} }
    struct ComputeBasePtrOfStridedBatch
@@ -617,7 +417,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        CShuffleNXdlPerWavePerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        LoopSched>;
        LoopSched,
        matrix_padder.PadN>;
    // Argument
    struct Argument : public BaseArgument
@@ -661,7 +462,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
              b1_element_op_{b1_element_op},
              c_element_op_{c_element_op},
              batch_count_(Batch),
              compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC}
              compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC},
              raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}
        {
            if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                           b_grid_desc_bk0_n_bk1_,
@@ -694,6 +496,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
        CElementwiseOperation c_element_op_;
        index_t batch_count_;
        ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;

        // For robust IsSupportedArgument() check
        std::vector<index_t> raw_lengths_m_n_k_o_;
    };

    // Invoker
@@ -797,6 +602,31 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
            return false;
        }
        // Note: we need raw lengths since threadwise copy cannot handle vector load when part
        // of the vector is out of bounds
const auto MRaw = arg.raw_lengths_m_n_k_o_[0];
const auto NRaw = arg.raw_lengths_m_n_k_o_[1];
const auto KRaw = arg.raw_lengths_m_n_k_o_[2];
const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
// Check scalar per vector requirement
const auto a_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
const auto b_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
const auto b1_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
const auto c_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
        return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                           arg.b_grid_desc_bk0_n_bk1_,
                                           arg.b1_grid_desc_bk0_n_bk1_,
@@ -913,7 +743,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
            << MPerBlock << ", "
            << Gemm1NPerBlock << ", "
            << Gemm1KPerBlock << ", "
            << B1K1 << ">";
            << B1K1 << ", "
            << getGemmSpecializationString(GemmSpec) << ">";
        // clang-format on
        return str.str();
...
@@ -3,19 +3,10 @@
#pragma once

#include <iostream>
#include <sstream>
#include <memory>
#include <vector>

#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
@@ -24,227 +15,54 @@ namespace device {
template <typename InDataType,
          typename AccDataType,
          typename OutDataType,
          index_t Rank,
          index_t NumReduceDim,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
struct DeviceSoftmax : public DeviceNormalization
{
    static constexpr index_t kRank = Rank;
    static constexpr index_t kNumReduceDim = NumReduceDim;

    virtual index_t GetRank() const override { return kRank; }

    virtual index_t GetNumReduceDim() const override { return kNumReduceDim; }

    using PassThrough = tensor_operation::element_wise::PassThrough;

    // Used for freeloading of some handy functions from DeviceReduceMultiBlock
    using Reduction = DeviceReduceMultiBlock<InDataType,
                                             AccDataType,
                                             OutDataType,
                                             Rank,
                                             NumReduceDim,
                                             reduce::Add,
                                             PassThrough, // InElementwiseOperation
                                             PassThrough, // AccElementwiseOperation
                                             InMemoryDataOperationEnum::Set,
                                             false, // PropagateNan
                                             false, // OutputIndex
                                             false, // HaveIndexInputIfOutputIndex
                                             BlockSize,
                                             MThreadClusterSize,
                                             KThreadClusterSize,
                                             MThreadSliceSize,
                                             KThreadSliceSize,
                                             InSrcVectorDim,
                                             InSrcVectorSize,
                                             1>; // OutDstVectorSize

    using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk<InDataType,
OutDataType,
AccDataType,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize,
false>;
using GridwiseSoftmaxSweepOnce = GridwiseSoftmax_mk_to_mk<InDataType,
OutDataType,
AccDataType,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize,
true>;
struct Argument : public Reduction::Argument
{
Argument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> reduceDims,
AccDataType alpha,
AccDataType beta,
const InDataType* in_dev,
OutDataType* out_dev)
: Reduction::Argument(inLengths,
inStrides,
{},
{},
reduceDims,
0.0f, // alpha
0.0f, // beta
in_dev,
nullptr,
out_dev,
nullptr,
PassThrough{},
PassThrough{}),
// FIXME: The base class DeviceReduceMultiBlock::Argument only supports alpha/beta of
// float32 precision. Make it support any data type so the fields can be removed.
alpha_(alpha),
beta_(beta)
{
// std::cout << "blkGroupSize= " << this->blkGroupSize
// << ", numBlockTileIteration= " << this->numBlockTileIteration
// << ", gridSize=" << this->gridSize
// << ", invariant_total_length=" << this->invariant_total_length <<
// std::endl;
}
AccDataType alpha_;
AccDataType beta_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
bool sweep_once =
in_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
const auto kernel_main = sweep_once ? kernel_softmax<GridwiseSoftmaxSweepOnce,
InDataType,
OutDataType,
AccDataType,
GridDesc_M_K>
: kernel_softmax<GridwiseSoftmaxGeneric,
InDataType,
OutDataType,
AccDataType,
GridDesc_M_K>;
float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config,
kernel_main,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
out_grid_desc_m_k,
arg.blkGroupSize,
arg.numBlockTileIteration,
arg.alpha_,
arg.in_dev_,
arg.beta_,
arg.out_dev_);
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
if(!Reduction::IsSupportedArgument(p_arg_))
{
return false;
}
if(p_arg_->inLengths_[Rank - 1] % OutDstVectorSize != 0)
{
return false;
}
return true;
};
// inLengths: input tensor extent(s) from high to low dimension
// inStrides: input tensor stride(s) from high to low dimension
// reduceDims: the dimension(s) the softmax normalization operate on
// alpha: typeless pointer in host memory storing the alpha scaling value as type AccDataType
// beta: typeless pointer in host memory storing the beta scaling value as type AccDataType
// in_dev: typeless const pointer in device memory storing the input tensor
// out_dev: typeless pointer in device memory storing the output tensor
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<int> reduceDims,
const void* alpha,
const void* beta,
const void* in_dev,
void* out_dev) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
reduceDims,
*static_cast<const AccDataType*>(alpha),
*static_cast<const AccDataType*>(beta),
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev));
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceSoftmax<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
template <typename InDataType,
typename AccDataType,
typename OutDataType,
typename InElementwiseOp,
typename AccElementwiseOp,
index_t Rank>
using DeviceSoftmaxPtr = std::unique_ptr<
DeviceSoftmax<InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
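Aside: the MakeArgumentPointer interface above takes alpha and beta as typeless host pointers and dereferences them as AccDataType, so the caller must pass the address of a float-width host variable rather than a literal. A minimal calling sketch follows; `softmax_op`, `in_dev` and `out_dev` are assumed to exist already and are illustrative names, not part of this diff.

// Sketch only. alpha/beta must be host variables of AccDataType (float) because the
// Argument constructor reads them via *static_cast<const AccDataType*>(...).
float alpha = 1.0f;
float beta  = 0.0f;

auto argument_ptr = softmax_op->MakeArgumentPointer({1, 4, 64, 256},                  // inLengths, high to low dim
                                                    {4 * 64 * 256, 64 * 256, 256, 1}, // inStrides
                                                    {2, 3},                           // reduceDims
                                                    &alpha,
                                                    &beta,
                                                    in_dev,
                                                    out_dev);
auto invoker_ptr = softmax_op->MakeInvokerPointer();

if(softmax_op->IsSupportedArgument(argument_ptr.get()))
{
    invoker_ptr->Run(argument_ptr.get(), StreamConfig{});
}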
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
typename InElementwiseOp,
typename AccElementwiseOp,
index_t Rank,
index_t NumReduceDim,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
AccDataType,
OutDataType,
InElementwiseOp,
AccElementwiseOp,
Rank>
{
static constexpr index_t kRank = Rank;
static constexpr index_t kNumReduceDim = NumReduceDim;
virtual index_t GetRank() const override { return kRank; }
virtual index_t GetNumReduceDim() const override { return kNumReduceDim; }
// Used for freeloading of some handy functions from DeviceReduceMultiBlock
using Reduction = DeviceReduceMultiBlock<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
reduce::Add,
InElementwiseOp,
AccElementwiseOp,
InMemoryDataOperationEnum::Set,
false, // PropagateNan
false, // OutputIndex
false, // HaveIndexInputIfOutputIndex
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
1>; // OutDstVectorSize
using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk<InDataType,
OutDataType,
AccDataType,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize,
false>;
using GridwiseSoftmaxSweepOnce = GridwiseSoftmax_mk_to_mk<InDataType,
OutDataType,
AccDataType,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize,
true>;
struct Argument : public Reduction::Argument
{
Argument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> reduceDims,
AccDataType alpha,
AccDataType beta,
const InDataType* in_dev,
OutDataType* out_dev,
InElementwiseOp in_elementwise_op,
AccElementwiseOp acc_elementwise_op)
: Reduction::Argument(inLengths,
inStrides,
{},
{},
reduceDims,
0.0f, // alpha
0.0f, // beta
in_dev,
nullptr,
out_dev,
nullptr,
in_elementwise_op,
acc_elementwise_op),
// FIXME: The base class DeviceReduceMultiBlock::Argument only supports alpha/beta of
// float32 precision. Make it support any data type so the fields can be removed.
alpha_(alpha),
beta_(beta)
{
// std::cout << "blkGroupSize= " << this->blkGroupSize
// << ", numBlockTileIteration= " << this->numBlockTileIteration
// << ", gridSize=" << this->gridSize
// << ", invariant_total_length=" << this->invariant_total_length <<
// std::endl;
}
AccDataType alpha_;
AccDataType beta_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
bool sweep_once =
in_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
const auto kernel_main = sweep_once ? kernel_softmax<GridwiseSoftmaxSweepOnce,
InDataType,
OutDataType,
AccDataType,
GridDesc_M_K>
: kernel_softmax<GridwiseSoftmaxGeneric,
InDataType,
OutDataType,
AccDataType,
GridDesc_M_K>;
float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config,
kernel_main,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
out_grid_desc_m_k,
arg.blkGroupSize,
arg.numBlockTileIteration,
arg.alpha_,
arg.in_dev_,
arg.beta_,
arg.out_dev_);
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
if(!Reduction::IsSupportedArgument(p_arg_))
{
return false;
}
if(p_arg_->inLengths_[Rank - 1] % OutDstVectorSize != 0)
{
return false;
}
return true;
};
//
// @brief Makes a pointer to Argument class.
//
// @param[in] inLengths Input tensor extent(s) from high to low dimension
// @param[in] inStrides Input tensor stride(s) from high to low dimension
// @param[in] reduceDims The dimension(s) the normalization operation is applied
// @param[in] alpha Typeless pointer in host memory storing the alpha scaling
// value as type AccDataType
// @param[in] beta Typeless pointer in host memory storing the beta scaling
// value as type AccDataType
// @param[in] in_dev Typeless const pointer in device memory storing the input
// tensor
// @param out_dev Typeless pointer in device memory storing the output tensor
// @param[in] in_elementwise_op The input elementwise operation.
// @param[in] acc_elementwise_op The accumulation elementwise operation.
//
// @return Unique pointer to the Argument class.
//
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<int> reduceDims,
const void* alpha,
const void* beta,
const void* in_dev,
void* out_dev,
InElementwiseOp in_elementwise_op,
AccElementwiseOp acc_elementwise_op) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
reduceDims,
*static_cast<const AccDataType*>(alpha),
*static_cast<const AccDataType*>(beta),
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev),
in_elementwise_op,
acc_elementwise_op);
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceSoftmax<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
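Aside: a sketch of one concrete instantiation of DeviceSoftmaxImpl, following the template parameter order declared above. The tile parameters below are illustrative placeholders, not tuned values from the instance library. With these numbers the sweep-once kernel is selected whenever the reduced length fits within KThreadClusterSize * KThreadSliceSize = 32 * 8 = 256 elements, per the Invoker::Run logic shown earlier.

// Illustrative instantiation only; the parameter values are assumptions, not library tunings.
using PassThrough_ = ck::tensor_operation::element_wise::PassThrough;

using DeviceSoftmaxF16Rank4 =
    ck::tensor_operation::device::DeviceSoftmaxImpl<ck::half_t,   // InDataType
                                                    float,        // AccDataType
                                                    ck::half_t,   // OutDataType
                                                    PassThrough_, // InElementwiseOp
                                                    PassThrough_, // AccElementwiseOp
                                                    4,            // Rank
                                                    2,            // NumReduceDim
                                                    256,          // BlockSize
                                                    8,            // MThreadClusterSize
                                                    32,           // KThreadClusterSize
                                                    1,            // MThreadSliceSize
                                                    8,            // KThreadSliceSize
                                                    1,            // InSrcVectorDim (innermost, reduced axis)
                                                    8,            // InSrcVectorSize
                                                    8>;           // OutDstVectorSize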
...@@ -200,8 +200,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
        const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
        const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
        const CGridDesc_M_N& c_grid_desc_m_n,
-       const Block2CTileMap& block_2_ctile_map,
-       const std::vector<index_t>& lengths_m_n_k_o)
+       const Block2CTileMap& block_2_ctile_map)
    {
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
...@@ -217,13 +216,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
            return false;
        }
-       // K is rounded to nearest multiples of K1 during tensor transformation so instead get KRaw
-       const auto KRaw = lengths_m_n_k_o[2];
-       if(!(KRaw % AK1 == 0 && KRaw % BK1 == 0))
-       {
-           return false;
-       }
        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
             Gemm1N % Gemm1NPerBlock == 0))
        {
...@@ -602,9 +594,17 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
            static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
            b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
-       constexpr index_t Gemm1KPack = math::max(
-           math::lcm(MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size, B1K1),
-           MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+       // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
+       // selected_mfma.k_per_blk <= Gemm1KPack
+       //
+       // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
+       // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
+       // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
+       // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
+       // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
+       // therefore we may just as well assign Gemm1KPack = group_size
+       constexpr index_t Gemm1KPack =
+           MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
        auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
            BlockSize,
...
...@@ -75,7 +75,8 @@ template <typename FloatAB,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
-         LoopScheduler LoopSched>
+         LoopScheduler LoopSched,
+         bool PadN>
struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
    static_assert(LoopSched == LoopScheduler::Default,
...@@ -330,6 +331,36 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
    };
template <bool Pred>
struct ElementOpPredicatedResetNaNToMinusInf;
template <>
struct ElementOpPredicatedResetNaNToMinusInf<true>
{
template <typename ElementOp, typename OutT, typename InT>
__host__ __device__ void Run(OutT& y, const ElementOp& op, const InT& x)
{
if(ck::math::isnan(x))
{
y = -ck::NumericLimits<float>::Infinity();
}
else
{
op(y, x);
}
}
};
template <>
struct ElementOpPredicatedResetNaNToMinusInf<false>
{
template <typename ElementOp, typename OutT, typename InT>
__host__ __device__ void Run(OutT& y, const ElementOp& op, const InT& x)
{
op(y, x);
}
};
    template <bool HasMainKBlockLoop, typename Block2CTileMap>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
...@@ -348,14 +379,20 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                               const Block2CTileMap& block_2_ctile_map)
    {
-       const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-           p_a_grid,
-           a_grid_desc_ak0_m_ak1.GetElementSpaceSize(),
-           NumericLimits<FloatAB>::QuietNaN());
-       const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-           p_b_grid,
-           b_grid_desc_bk0_n_bk1.GetElementSpaceSize(),
-           NumericLimits<FloatAB>::QuietNaN());
+       const auto a_grid_buf =
+           conditional_expr<PadN>(make_dynamic_buffer<AddressSpaceEnum::Global>(
+                                      p_a_grid,
+                                      a_grid_desc_ak0_m_ak1.GetElementSpaceSize(),
+                                      NumericLimits<FloatAB>::QuietNaN()),
+                                  make_dynamic_buffer<AddressSpaceEnum::Global>(
+                                      p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()));
+       const auto b_grid_buf =
+           conditional_expr<PadN>(make_dynamic_buffer<AddressSpaceEnum::Global>(
+                                      p_b_grid,
+                                      b_grid_desc_bk0_n_bk1.GetElementSpaceSize(),
+                                      NumericLimits<FloatAB>::QuietNaN()),
+                                  make_dynamic_buffer<AddressSpaceEnum::Global>(
+                                      p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()));
        const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -608,9 +645,17 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
            b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
-       constexpr index_t Gemm1KPack = math::max(
-           math::lcm(MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size, B1K1),
-           MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+       // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
+       // selected_mfma.k_per_blk <= Gemm1KPack
+       //
+       // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
+       // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
+       // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
+       // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
+       // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
+       // therefore we may just as well assign Gemm1KPack = group_size
+       constexpr index_t Gemm1KPack =
+           MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
        auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
            BlockSize,
...@@ -680,7 +725,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            FloatGemmAcc,
            decltype(threadid_to_m_n_thread_cluster_adaptor),
            decltype(thread_cluster_desc_m_n),
-           decltype(thread_slice_desc_m_n)>{};
+           decltype(thread_slice_desc_m_n)
+#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
+           ,
+           true
+#endif
+           >{};
        const index_t num_gemm1_k_block_outer_loop =
            b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
...@@ -721,8 +771,15 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                               num_k_block_main_loop);
            // Acc0 elementwise Op
+#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
            static_for<0, acc_thread_buf.Size(), 1>{}(
                [&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
#else
static_for<0, acc_thread_buf.Size(), 1>{}([&](auto i) {
ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run(
acc_thread_buf(i), acc_element_op, acc_thread_buf[i]);
});
#endif
            block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
...
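Aside: the new PadN path fills the out-of-bounds region of the A and B global buffers with QuietNaN, and in the non-workaround branch ElementOpPredicatedResetNaNToMinusInf<true> rewrites any NaN it sees to -infinity before the softmax, so exp(-inf) = 0 and the padded columns drop out of the row sums. The conditional_expr<PadN>(...) helper used above acts as a compile-time selector; a rough stand-in (not the CK implementation, signature assumed) behaves like this:

// Hypothetical illustration only; the real ck utility may differ in signature and
// value-category handling.
template <bool Pred, typename X, typename Y>
constexpr auto conditional_expr_sketch(X x, Y y)
{
    if constexpr(Pred)
        return x; // e.g. the NaN-filled, padded buffer
    else
        return y; // e.g. the plain, unpadded buffer
}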
...@@ -32,6 +32,20 @@ void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_i
                                                   PassThrough,
                                                   PassThrough>>>& instances);
void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
Col,
Col,
Row,
F16,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <typename ALayout,
          typename B0Layout,
          typename B1Layout,
...@@ -82,6 +96,12 @@ struct DeviceOperationInstanceFactory<
            add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
                op_ptrs);
        }
else if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Col> && is_same_v<CLayout, Row>)
{
add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
op_ptrs);
}
        }
        return op_ptrs;
    }
...
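Aside: with the new gmk_gnk_gon_gmo branch above, a client can now obtain Row/Col/Col/Row instances through the same factory. A hedged sketch, assuming the factory is keyed on the DeviceBatchedGemmGemm type exactly as it appears in the declarations of this header:

// Sketch only; aliases and namespaces follow the declarations visible in this header.
using ck::tensor_operation::device::DeviceBatchedGemmGemm;
using ck::tensor_operation::device::instance::DeviceOperationInstanceFactory;

using DeviceOp = DeviceBatchedGemmGemm<Row, Col, Col, Row,       // A/B0/B1/C layouts (gmk_gnk_gon_gmo)
                                       F16, F16, F16, F16,       // A/B0/B1/C data types
                                       PassThrough, PassThrough, // element-wise ops
                                       PassThrough, PassThrough,
                                       PassThrough>;

const auto op_ptrs = DeviceOperationInstanceFactory<DeviceOp>::GetInstances();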
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <type_traits>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
void add_device_softmax_f16_f16_rank3_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>&);
void add_device_softmax_f16_f16_rank4_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>&);
void add_device_softmax_f32_f32_rank3_instances(
std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3>>&);
void add_device_softmax_f32_f32_rank4_instances(
std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4>>&);
template <typename InDataType, typename AccDataType, typename OutDataType, index_t Rank>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::
DeviceSoftmax<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>>
{
using DeviceOp =
DeviceSoftmax<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(std::is_same_v<InDataType, F16> && std::is_same_v<AccDataType, F32> &&
std::is_same_v<OutDataType, F16>)
{
if constexpr(Rank == 3)
add_device_softmax_f16_f16_rank3_instances(op_ptrs);
else if constexpr(Rank == 4)
add_device_softmax_f16_f16_rank4_instances(op_ptrs);
}
else if constexpr(std::is_same_v<InDataType, F32> && std::is_same_v<AccDataType, F32> &&
std::is_same_v<OutDataType, F32>)
{
if constexpr(Rank == 3)
add_device_softmax_f32_f32_rank3_instances(op_ptrs);
else if constexpr(Rank == 4)
add_device_softmax_f32_f32_rank4_instances(op_ptrs);
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
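Aside: a minimal sketch of enumerating the registered softmax instances through the factory above, assuming this header is included and the instance library is linked; GetTypeString() is the method implemented by DeviceSoftmaxImpl earlier in this commit.

#include <iostream>

// Sketch only: query the f16-in / f32-acc / f16-out rank-4 softmax instances.
using PassThrough_ = ck::tensor_operation::element_wise::PassThrough;
using DeviceOp     = ck::tensor_operation::device::
    DeviceSoftmax<ck::half_t, float, ck::half_t, PassThrough_, PassThrough_, 4>;

const auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

for(const auto& op : op_ptrs)
{
    std::cout << op->GetTypeString() << std::endl; // e.g. "DeviceReduceSoftmax<256,...>"
}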
...@@ -78,6 +78,7 @@ target_include_directories(device_operations PUBLIC
  $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor>
  $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/problem_transform>
  $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/device>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/device/impl>
  $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/grid>
  $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/block>
  $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/warp>
...