Commit ab04f22f authored by Jing Zhang

add c_permute

parent ef18bd98
@@ -6,16 +6,14 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_c_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -44,7 +42,7 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmCPermutationXdl
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmCPermuteXdl
//######| ALayout| BLayout| AData| BData| CData| AccData| A| B| C| GEMM| Num| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
@@ -52,13 +50,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmCPermu
< Row, Col, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
using ReferenceBatchedGemmCPermutationInstance =
ck::tensor_operation::host::ReferenceBatchedGemmCPermutation<ADataType,
BDataType,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
ReferenceBatchedGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
int main(int argc, char* argv[])
{
@@ -66,26 +59,27 @@ int main(int argc, char* argv[])
int init_method = 1;
bool time_kernel = false;
const int M0 = rand() % 4 + 1;
const int M1 = 256;
const int N0 = rand() % 4 + 1;
const int N1 = 256;
// const int M = 88;
// const int N = 64;
// const int K = 88;
const int M = M0 * N1;
const int N = N0 * N1;
const int K = 128 * (rand() % 4 + 1);
const int M = 256;
const int N = 128;
const int K = 64;
const int stride_A = K;
const int stride_B = K;
// output layout [M0, N0, M1, N1]
const int stride_M0 = N1 * M1 * N0;
const int stride_M1 = N1;
const int stride_N0 = N1 * M1;
const int stride_N1 = 1;
const int G0 = 1024;
const int G1 = 10;
const int batch_count = G0 * G1;
int batch_count = rand() % 16 + 1;
// output layout - [G0, M, G1, N]
const int stride_B0 = M * G1 * N;
const int stride_B1 = N;
const int stride_M = G1 * N;
const int stride_N = 1;
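// With the strides above, a logical element (g0, g1, m, n) of the output lands at flat offset
//   g0 * stride_B0 + g1 * stride_B1 + m * stride_M + n * stride_N
//     = g0 * (M * G1 * N) + m * (G1 * N) + g1 * N + n,
// which is exactly the [G0, M, G1, N] memory order noted in the comment above.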
if(argc == 4)
{
@@ -102,8 +96,8 @@ int main(int argc, char* argv[])
}
// GEMM shape
ck::tensor_operation::device::GemmTransposeDesc gemm_transpose_desc{
M, N, K, stride_A, stride_B, M0, M1, N0, N1, stride_M0, stride_M1, stride_N0, stride_N1};
ck::tensor_operation::device::BatchedGemmCPermuteDesc batched_gemm_c_permute_desc{
G0, G1, M, N, stride_B0, stride_B1, stride_M, stride_N};
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
std::size_t row,
@@ -125,30 +119,28 @@ int main(int argc, char* argv[])
Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(batch_count, M, K, stride_A, ALayout{}));
Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(batch_count, K, N, stride_B, BLayout{}));
auto f_host_c_tensor_descriptor = [](std::size_t batch_count_,
std::size_t M0_,
std::size_t M1_,
std::size_t N0_,
std::size_t N1_,
std::size_t StrideM0_,
std::size_t StrideM1_,
std::size_t StrideN0_,
std::size_t StrideN1_) {
auto f_host_c_tensor_descriptor = [](std::size_t B0_,
std::size_t B1_,
std::size_t M_,
std::size_t N_,
std::size_t stride_B0_,
std::size_t stride_B1_,
std::size_t stride_M_,
std::size_t stride_N_) {
return HostTensorDescriptor(
std::vector<std::size_t>({batch_count_, M0_, M1_, N0_, N1_}),
std::vector<std::size_t>(
{M0_ * M1_ * N0_ * N1_, StrideM0_, StrideM1_, StrideN0_, StrideN1_}));
std::vector<std::size_t>({B0_, B1_, M_, N_}),
std::vector<std::size_t>({stride_B0_, stride_B1_, stride_M_, stride_N_}));
};
Tensor<CDataType> c_g_m0_m1_n0_n1_host_result(f_host_c_tensor_descriptor(
batch_count, M0, M1, N0, N1, stride_M0, stride_M1, stride_N0, stride_N1));
Tensor<CDataType> c_g0_g1_m_n_host_result(
f_host_c_tensor_descriptor(G0, G1, M, N, stride_B0, stride_B1, stride_M, stride_N));
Tensor<CDataType> c_g_m0_m1_n0_n1_device_result(f_host_c_tensor_descriptor(
batch_count, M0, M1, N0, N1, stride_M0, stride_M1, stride_N0, stride_N1));
Tensor<CDataType> c_g0_g1_m_n_device_result(
f_host_c_tensor_descriptor(G0, G1, M, N, stride_B0, stride_B1, stride_M, stride_N));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
std::cout << "c_g_m_n: " << c_g_m0_m1_n0_n1_host_result.mDesc << std::endl;
std::cout << "c_g0_g1_m_n: " << c_g0_g1_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
@@ -165,8 +157,7 @@ int main(int argc, char* argv[])
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) *
c_g_m0_m1_n0_n1_device_result.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_g0_g1_m_n_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data());
@@ -182,7 +173,12 @@ int main(int argc, char* argv[])
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
gemm_transpose_desc,
M,
N,
K,
stride_A,
stride_B,
batched_gemm_c_permute_desc,
a_element_op,
b_element_op,
c_element_op,
@@ -213,22 +209,36 @@ int main(int argc, char* argv[])
if(do_verification)
{
c_device_buf.FromDevice(c_g_m0_m1_n0_n1_device_result.mData.data());
c_device_buf.FromDevice(c_g0_g1_m_n_device_result.mData.data());
auto ref_batched_gemm = ReferenceBatchedGemmCPermutationInstance{};
auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
auto ref_invoker = ref_batched_gemm.MakeInvoker();
auto ref_argument = ref_batched_gemm.MakeArgument(a_g_m_k,
b_g_k_n,
c_g_m0_m1_n0_n1_host_result,
a_element_op,
b_element_op,
c_element_op);
Tensor<CDataType> c_g_m_n_host_result = HostTensorDescriptor(
std::vector<std::size_t>({batch_count, M, N}), std::vector<std::size_t>({M * N, N, 1}));
auto ref_argument = ref_batched_gemm.MakeArgument(
a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
pass = ck::utils::check_err(c_g_m0_m1_n0_n1_host_result.mData,
c_g_m0_m1_n0_n1_device_result.mData,
for(int g0 = 0; g0 < G0; g0++)
{
for(int g1 = 0; g1 < G1; g1++)
{
for(int m = 0; m < M; m++)
{
for(int n = 0; n < N; n++)
{
int g = g0 * G1 + g1;
c_g0_g1_m_n_host_result(g0, g1, m, n) = c_g_m_n_host_result(g, m, n);
}
}
}
}
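// The reference path computes a plain batched GEMM into c_g_m_n_host_result
// ([batch, M, N]); the loop above scatters that result into the permuted
// [G0, G1, M, N] host tensor via g = g0 * G1 + g1 so it can be compared
// element-wise against the device output below.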
pass = ck::utils::check_err(c_g0_g1_m_n_host_result.mData,
c_g0_g1_m_n_device_result.mData,
"Error: Incorrect results c");
}
@@ -10,17 +10,17 @@ namespace device {
struct BatchedGemmCPermuteDesc
{
ck::index_t B0_, B1_, M_, N_;
ck::index_t stride_B0_, stride_B1_, stride_M_, stride_N_;
ck::index_t G0_, G1_, M_, N_;
ck::index_t stride_G0_, stride_G1_, stride_M_, stride_N_;
};
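// Illustrative only (values taken from the example program, not part of this header):
// for an output stored as [G0, M, G1, N] in memory the descriptor could be filled as
//   BatchedGemmCPermuteDesc desc{G0, G1, M, N,
//                                M * G1 * N, // stride_G0
//                                N,          // stride_G1
//                                G1 * N,     // stride_M
//                                1};         // stride_N
// Strides are per-element and describe where each logical (g0, g1, m, n) index lives
// in the permuted output buffer.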
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceBatchedGemmCPermutate : public BaseOperator
struct DeviceBatchedGemmCPermute : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
@@ -28,7 +28,7 @@ struct DeviceBatchedGemmCPermutate : public BaseOperator
index_t K,
index_t stride_A,
index_t stride_B,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
@@ -40,10 +40,8 @@
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceBatchedGemmCPermutatePtr =
std::unique_ptr<DeviceBatchedGemmCPermutate<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>>;
using DeviceBatchedGemmCPermutePtr = std::unique_ptr<
DeviceBatchedGemmCPermute<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
@@ -13,8 +13,6 @@
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
@@ -41,7 +39,7 @@ namespace device {
*
* \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemmCPermutate and GroupedGemm (and the corresponding GEMM fusion).
* realize BatchedGemmCPermute and GroupedGemm (and the corresponding GEMM fusion).
*
*/
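/*
 * Rough sketch (illustrative, not the exact kernel body below) of how these pieces fit
 * together in a batched launch: each workgroup derives a batch index g_idx from its block
 * id, asks ComputePtrOffsetOfBatch for per-batch pointer offsets, and then runs the shared
 * GridwiseGemm on the shifted pointers. The GetAPtrOffset / GetBPtrOffset names are
 * assumed to mirror the GetCPtrOffset member shown further down.
 *
 *   const index_t g_idx = ...;  // derived from the block id and blocks-per-batch
 *   p_a += compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
 *   p_b += compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
 *   p_e += compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
 *   GridwiseGemm::Run(p_a, p_b, p_e, ...);
 */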
template <typename GridwiseGemm,
@@ -160,8 +158,7 @@ template <typename ALayout,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceBatchedGemmCPermutateXdl
: public DeviceBatchedGemmCPermutate<AElementwiseOperation,
struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
@@ -247,14 +244,10 @@ struct DeviceBatchedGemmCPermutateXdl
}
}
static auto MakeCGridDescriptor_M_N(index_t M,
index_t N,
index_t stride_M,
index_t stride_N)
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t stride_M, index_t stride_N)
{
const auto c_grid_desc_m_n = [&]() {
return make_naive_tensor_descriptor(
make_tuple(M, N), make_tuple(stride_M, stride_N));
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(stride_M, stride_N));
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
@@ -279,16 +272,53 @@ struct DeviceBatchedGemmCPermutateXdl
}
}
static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0,
index_t G1,
index_t M,
index_t N,
index_t stride_G0,
index_t stride_G1,
index_t stride_M,
index_t stride_N)
{
const auto e_grid_desc_g0_g1_m_n = [&]() {
return make_naive_tensor_descriptor(
make_tuple(G0, G1, M, N), make_tuple(stride_G0, stride_G1, stride_M, stride_N));
}();
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
e_grid_desc_g0_g1_m_n,
make_tuple(make_pass_through_transform(G0),
make_pass_through_transform(G1),
make_right_pad_transform(M, PadM),
make_right_pad_transform(N, PadN)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
}
else
{
return e_grid_desc_g0_g1_m_n;
}
}
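// Note: under GemmSpecialization::MNPadding only the M and N lengths are right-padded
// to multiples of MPerBlock / NPerBlock; G0 and G1 are passed through unchanged, so the
// batch layout of the permuted output descriptor is preserved.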
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1, 1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1));
using EGridDesc_G0_G1_M_N = decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1));
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch(index_t Batchstride_A,
index_t Batchstride_B,
index_t BatchStrideC)
: Batchstride_A_(Batchstride_A), Batchstride_B_(Batchstride_B), BatchStrideC_(BatchStrideC)
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n)
: Batchstride_A_(Batchstride_A),
Batchstride_B_(Batchstride_B),
e_grid_desc_g0_g1_m_n_(e_grid_desc_g0_g1_m_n)
{
}
@@ -304,13 +334,16 @@ struct DeviceBatchedGemmCPermutateXdl
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideC_);
const index_t G1 = e_grid_desc_g0_g1_m_n_.GetLength(I1);
index_t b0 = g_idx / G1;
index_t b1 = g_idx % G1;
return e_grid_desc_g0_g1_m_n_.CalculateOffset(make_multi_index(b0, b1, 0, 0));
}
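// Example: with G1 = 10 and g_idx = 23, b0 = 2 and b1 = 3, so the returned offset is
// that of element (2, 3, 0, 0), i.e. 2 * stride_G0 + 3 * stride_G1 for a naive
// (unpadded) E descriptor.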
private:
index_t Batchstride_A_;
index_t Batchstride_B_;
index_t BatchStrideC_;
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
};
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
@@ -383,20 +416,29 @@ struct DeviceBatchedGemmCPermutateXdl
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
BatchCount_(BatchCount),
a_grid_desc_k0_m_k1_{DeviceBatchedGemmCPermutateXdl::MakeAGridDescriptor_K0_M_K1(
M, K, stride_A)},
b_grid_desc_k0_n_k1_{DeviceBatchedGemmCPermutateXdl::MakeBGridDescriptor_K0_N_K1(
K, N, stride_B)},
c_grid_desc_m_n_{DeviceBatchedGemmCPermutateXdl::MakeCGridDescriptor_M_N(
a_grid_desc_k0_m_k1_{
DeviceBatchedGemmCPermuteXdl::MakeAGridDescriptor_K0_M_K1(M, K, stride_A)},
b_grid_desc_k0_n_k1_{
DeviceBatchedGemmCPermuteXdl::MakeBGridDescriptor_K0_N_K1(K, N, stride_B)},
c_grid_desc_m_n_{DeviceBatchedGemmCPermuteXdl::MakeCGridDescriptor_M_N(
batched_gemm_c_permute_desc.M_,
batched_gemm_c_permute_desc.N_,
batched_gemm_c_permute_desc.stride_M_,
batched_gemm_c_permute_desc.stride_N_)},
e_grid_desc_g0_g1_m_n_{DeviceBatchedGemmCPermuteXdl::MakeEGridDescriptor_G0_G1_M_N(
batched_gemm_c_permute_desc.G0_,
batched_gemm_c_permute_desc.G1_,
batched_gemm_c_permute_desc.M_,
batched_gemm_c_permute_desc.N_,
batched_gemm_c_permute_desc.stride_G0_,
batched_gemm_c_permute_desc.stride_G1_,
batched_gemm_c_permute_desc.stride_M_,
batched_gemm_c_permute_desc.stride_N_)},
c_grid_desc_mblock_mperblock_nblock_nperblock{},
compute_ptr_offset_of_batch_{
type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
e_grid_desc_g0_g1_m_n_},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
@@ -422,6 +464,7 @@ struct DeviceBatchedGemmCPermutateXdl
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock;
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
Block2CTileMap block_2_ctile_map_;
@@ -433,7 +476,7 @@ struct DeviceBatchedGemmCPermutateXdl
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceBatchedGemmCPermutateXdl::Argument;
using Argument = DeviceBatchedGemmCPermuteXdl::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
@@ -456,7 +499,7 @@ struct DeviceBatchedGemmCPermutateXdl
arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseBatchedGemmCPermutate_km_kn_m0m1n0n1_xdlops_v2r3 has invalid "
"wrong! GridwiseBatchedGemmCPermute_km_kn_m0m1n0n1_xdlops_v2r3 has invalid "
"setting");
}
@@ -473,8 +516,8 @@ struct DeviceBatchedGemmCPermutateXdl
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceBatchedGemmCPermutateXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceBatchedGemmCPermutateXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceBatchedGemmCPermuteXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceBatchedGemmCPermuteXdl::BGridDesc_K0_N_K1>,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation,
BElementwiseOperation,
@@ -574,7 +617,8 @@ struct DeviceBatchedGemmCPermutateXdl
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
@@ -615,7 +659,7 @@ struct DeviceBatchedGemmCPermutateXdl
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedGemmCPermutateXdl"
str << "DeviceBatchedGemmCPermuteXdl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "