Commit ba1756e3 authored by Po-Yen, Chen

Merge branch 'develop' into feature/add-permute-device-op

parents ea23062c 7c788e10
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_padded_batched_gemm_scale_softmax_gemm_xdl_fp16 padded_batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_custom_target(example_batched_gemm_scale_softmax_gemm)
add_dependencies(example_batched_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
add_dependencies(example_batched_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_batched_gemm_scale_softmax_gemm example_padded_batched_gemm_scale_softmax_gemm_xdl_fp16)
add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16)
......@@ -58,7 +58,7 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNOPadding;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
......@@ -149,8 +149,8 @@ int main(int argc, char* argv[])
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 128;
ck::index_t N = 1024;
ck::index_t M = 120;
ck::index_t N = 1000;
ck::index_t K = 64;
ck::index_t O = 128;
ck::index_t StrideA = -1;
......
......@@ -55,7 +55,7 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ALayout,
......@@ -73,7 +73,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftma
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmDefault,
GemmSpec,
1,
256,
128, // MPerBlock
......@@ -144,8 +144,8 @@ int main(int argc, char* argv[])
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t M = 1020;
ck::index_t N = 1020;
ck::index_t K = 64;
ck::index_t O = 128;
ck::index_t BatchCount = 4;
......
......@@ -649,6 +649,9 @@ struct BlockwiseGemmXdlops_v2
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
static_assert(KPerThread % KPack == 0,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
......
......@@ -503,13 +503,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
if(!DeviceOp::IsSupportedArgument(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
throw std::runtime_error("wrong! unsupported argument");
}
const index_t grid_size =
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename B0Layout,
typename B1Layout,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<>
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
{
struct ProblemDesc
{
// Overall problem shape
index_t M;
index_t N;
index_t K;
index_t O;
index_t Batch;
// Stride for A/B0/B1; layout determined by template args
index_t StrideA;
index_t StrideB0;
index_t StrideB1;
index_t BatchStrideA;
index_t BatchStrideB0;
index_t BatchStrideB1;
// Lengths and strides for output C
std::vector<index_t> c_gs_ms_os_lengths;
std::vector<index_t> c_gs_ms_os_strides;
};
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*> p_a_vec,
std::vector<const void*> p_b0_vec,
std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec,
std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -486,4 +486,48 @@ __host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx,
return is_valid;
}
// This wrapper class is for grouped gemm: it subtracts a per-problem offset from blockIdx so
// that the workgroups assigned to a given gemm problem see a top index remapped to the range
// [0, grid_size_per_gemm).
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMap
{
using underlying_type = UnderlyingBlockToCTileMap;
OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start)
{
block_to_ctile_map_ = block_to_ctile_map;
block_start_ = block_start;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
return block_to_ctile_map_.CalculateBottomIndex(
make_multi_index(idx_top[Number<0>{}] - block_start_));
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
template <typename CGridDesc_M_N>
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
}
UnderlyingBlockToCTileMap block_to_ctile_map_;
index_t block_start_;
};
} // namespace ck
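As a side note on the wrapper above: the block_start value it consumes is typically a prefix sum of per-problem grid sizes computed on the host. The following is a minimal, self-contained sketch of that bookkeeping, included for illustration only; the names and values are assumptions, not part of this commit.

// Host-side sketch (assumed, illustrative only) of deriving per-problem block offsets
// before wrapping each problem's tile map in OffsettedBlockToCTileMap.
#include <cstdio>
#include <vector>

int main()
{
    // Hypothetical number of C tiles (grid size) for each gemm problem in a group.
    std::vector<int> grid_size_per_gemm = {12, 4, 20};
    std::vector<int> block_start;
    int total = 0;
    for(int g : grid_size_per_gemm)
    {
        block_start.push_back(total); // first blockIdx.x owned by this problem
        total += g;
    }
    // A workgroup with blockIdx.x == 14 belongs to problem 1; the wrapper's
    // CalculateBottomIndex subtracts block_start[1] == 12, yielding local top index 2.
    std::printf("total grid size = %d, problem 1 starts at block %d\n", total, block_start[1]);
    return 0;
}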
......@@ -881,9 +881,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
FloatGemmAcc c_new =
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
math::exp(max[iM] - running_max_new[iM]) * acc1) /
running_sum_new[iM]; // O_new
running_sum_new[iM]; // Formula by Dao et al.,
// https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
c_thread_buf(I) = c_new;
c_thread_buf(I) = c_new; // O_new
});
});
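For reference, the per-row rescaling computed above is the running ("online") softmax update described in Dao et al. 2022, section 3.1 (linked in the comment). Mapping the paper's symbols onto the buffers in this kernel is an interpretation, not something stated in the diff: with m = running_max, \tilde{m} = max, \ell = running_sum, O = c and \tilde{P}V = acc1, one update step reads

m^{new} = \max(m, \tilde{m}), \qquad
\ell^{new} = e^{\,m - m^{new}}\,\ell + e^{\,\tilde{m} - m^{new}}\,\tilde{\ell}, \qquad
O^{new} = \frac{\ell\, e^{\,m - m^{new}}\, O + e^{\,\tilde{m} - m^{new}}\, \tilde{P}V}{\ell^{new}}

where the last expression is exactly c_new in the code and running_sum_new holds \ell^{new}.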
......
......@@ -34,17 +34,15 @@ __device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t
y0 = vy0.template AsType<half2_t>()[I0];
y1 = vy1.template AsType<half2_t>()[I0];
#else
asm volatile("\n \
v_pack_b32_f16 %0, %1, %2 \n \
"
: "=v"(y0)
: "v"(x0), "v"(x1));
asm volatile("\n \
v_pack_b32_f16 %0, %1, %2, op_sel:[1, 1] \n \
"
: "=v"(y1)
: "v"(x0), "v"(x1));
constexpr int32_t m0 = 0x05040100;
constexpr int32_t m1 = 0x07060302;
// ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
//                   -- -- -- --     -- -- -- --                       -- -- -- --
//          index      7  6  5  4      3  2  1  0                      33 77 44 88
// index is reversed because of little endianness (least significant bits first)
y0 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0));
y1 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m1));
#endif
}
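The selector constants used above can be sanity-checked on the host. Below is a small emulation of the v_perm_b32 / __builtin_amdgcn_perm byte-select behaviour for selector byte values 0-7, a sketch for illustration and not part of this commit; selector values >= 8 have special ISA meanings that are not modeled here.

// Host-side emulation (assumption: illustrative only) of the v_perm_b32 byte select,
// enough to verify constants such as m0 = 0x05040100 and m1 = 0x07060302 offline.
#include <cstdint>
#include <cstdio>

static uint32_t perm_b32(uint32_t src0, uint32_t src1, uint32_t sel)
{
    // Combined 8-byte source: byte indices 0..3 come from src1, indices 4..7 from src0.
    uint8_t bytes[8];
    for(int i = 0; i < 4; ++i)
    {
        bytes[i]     = static_cast<uint8_t>(src1 >> (8 * i));
        bytes[i + 4] = static_cast<uint8_t>(src0 >> (8 * i));
    }
    uint32_t result = 0;
    for(int i = 0; i < 4; ++i)
    {
        const uint32_t idx = (sel >> (8 * i)) & 0xff; // selector byte for result byte i
        result |= static_cast<uint32_t>(bytes[idx & 0x7]) << (8 * i);
    }
    return result;
}

int main()
{
    // Reproduces the example in the comment above: expected output 0x33774488.
    std::printf("0x%08x\n", perm_b32(0x11223344u, 0x55667788u, 0x05010400u));
    // With m0 = 0x05040100, perm(x1, x0, m0) packs the low fp16 halves of x1 and x0:
    // expected output 0xBBBBDDDD.
    std::printf("0x%08x\n", perm_b32(0xAAAABBBBu, 0xCCCCDDDDu, 0x05040100u));
    return 0;
}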
......@@ -106,16 +104,14 @@ __device__ void transpose_int8_4x4(const int8x4_t& x0,
//                   -- -- -- --     -- -- -- --                       -- -- -- --
//          index      7  6  5  4      3  2  1  0                      33 77 44 88
// index is reversed because of little endianness (least significant bits first)
// clang-format off
asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t0) : "v"(bit_cast<int32_t>(x1)), "v"(bit_cast<int32_t>(x0)), "s"(m0));
asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t1) : "v"(bit_cast<int32_t>(x3)), "v"(bit_cast<int32_t>(x2)), "s"(m0));
asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z0) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m1));
asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z1) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m2));
asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t0) : "v"(bit_cast<int32_t>(x1)), "v"(bit_cast<int32_t>(x0)), "s"(m3));
asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t1) : "v"(bit_cast<int32_t>(x3)), "v"(bit_cast<int32_t>(x2)), "s"(m3));
asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z2) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m1));
asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z3) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m2));
// clang-format on
t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0);
t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m0);
z0 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
z1 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m3);
t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m3);
z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
y0 = bit_cast<int8x4_t>(z0);
y1 = bit_cast<int8x4_t>(z1);
......
......@@ -55,6 +55,22 @@ using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_
// clang-format on
>;
using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances =
std::tuple<
// clang-format off
//#######################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemm<Row,
Col,
......@@ -73,6 +89,9 @@ void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_g
add_device_operation_instances(
instances,
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{});
add_device_operation_instances(
instances,
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances{});
}
} // namespace instance
......
......@@ -105,6 +105,19 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16)
this->Run();
}
TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16_IrregularK)
{
this->lengths_ = std::vector<std::vector<int>>{{256, 256, 160, 160, 16},
{256, 64, 160, 64, 16},
{1024, 1024, 80, 80, 16},
{1024, 64, 80, 64, 16},
{4096, 4096, 40, 40, 16},
{4096, 64, 40, 64, 16}};
this->bench_ = true;
this->verify_ = false;
this->Run();
}
using ck::tensor_operation::device::GemmSpecialization;
// TODO: enable KPadding tests when it is implemented
......
......@@ -29,14 +29,19 @@ struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
using B1Layout = std::tuple_element_t<6, Tuple>;
using CLayout = std::tuple_element_t<7, Tuple>;
std::vector<std::vector<int>> lengths_ = {
{256, 256, 64, 64, 4},
{256, 256, 128, 128, 4},
{512, 512, 64, 64, 2},
{512, 512, 128, 128, 2},
{1024, 1024, 64, 64, 1},
{1024, 1024, 128, 128, 1},
};
std::vector<std::vector<int>> lengths_ = {{256, 256, 64, 64, 4},
{256, 256, 128, 128, 4},
{512, 512, 64, 64, 2},
{512, 512, 128, 128, 2},
{1024, 1024, 64, 64, 1},
{1024, 1024, 128, 128, 1},
{256, 256, 160, 160, 4},
{256, 64, 160, 64, 4},
{1024, 1024, 80, 80, 2},
{1024, 64, 80, 64, 2},
{4096, 4096, 40, 40, 1},
{4096, 64, 40, 64, 1}};
bool bench_ = false;
bool verify_ = true;
......