Commit edc494df authored by Anthony Chang

clang-format

parent 00331ee4
@@ -48,10 +48,10 @@ using B0Layout = Col;
using B1Layout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
@@ -113,14 +113,19 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_X
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
ADataType,
AccDataType,
AElementOp,
B0ElementOp,
CElementOp>;
using ReferenceGemm1Instance = ck::tensor_operation::host::
ReferenceBatchedGemm<ADataType, B1DataType, CDataType, AccDataType, AElementOp, B1ElementOp, CElementOp>;
B0DataType,
ADataType,
AccDataType,
AElementOp,
B0ElementOp,
CElementOp>;
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
AccDataType,
AElementOp,
B1ElementOp,
CElementOp>;
int main(int argc, char* argv[])
{
@@ -179,15 +184,15 @@ int main(int argc, char* argv[])
BatchCount = std::stoi(argv[8]);
StrideA = std::stoi(argv[9]);
StrideB0 = std::stoi(argv[10]);
StrideB1 = std::stoi(argv[11]);
StrideC = std::stoi(argv[12]);
BatchStrideA = std::stoi(argv[13]);
BatchStrideB0 = std::stoi(argv[14]);
BatchStrideB1 = std::stoi(argv[15]);
BatchStrideC = std::stoi(argv[16]);
}
else
{
@@ -282,35 +287,36 @@ int main(int argc, char* argv[])
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
BatchCount,
StrideA,
StrideB0,
StrideB1,
StrideC,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
BatchStrideC,
a_element_op,
b0_element_op,
b1_element_op,
c_element_op);
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument =
gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
BatchCount,
StrideA,
StrideB0,
StrideB1,
StrideC,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
BatchStrideC,
a_element_op,
b0_element_op,
b1_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
@@ -35,8 +35,8 @@ MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
return transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(make_tuple(
Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
make_unmerge_transform(
make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
@@ -694,7 +694,7 @@ struct BlockwiseGemmXdlops_v2
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
@@ -723,9 +723,8 @@ struct BlockwiseGemmXdlops_v2
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__ BlockwiseGemmXdlops_v2(
Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
__host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
@@ -738,8 +737,7 @@ struct BlockwiseGemmXdlops_v2
"wrong!");
}
__host__ __device__ BlockwiseGemmXdlops_v2(
const BlockwiseGemmXdlops_v2& other)
__host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other)
: a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin)
{
}
@@ -38,22 +38,23 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
kernel_gemm_gemm_xdl_cshuffle_v1(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -162,17 +163,17 @@ template <typename ALayout,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout,
BLayout,
B1Layout,
CLayout,
ADataType,
BDataType,
B1DataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation>
{
using DeviceOp = DeviceBatchedGemmGemm_Xdl_CShuffle;
@@ -405,12 +406,12 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
{
const auto B1K0 = KRaw / B1K1;
const auto b1_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b1_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b1_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b1_grid_desc_bk0_n_bk1;
}
@@ -426,16 +427,15 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b1_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b1_grid_desc_bk0_n_bk1;
}
}
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
@@ -537,9 +537,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
};
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmGemm_Xdl_CShuffle<
@@ -809,26 +809,10 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a,
p_b,
p_b1,
p_c,
MRaw,
NRaw,
KRaw,
Gemm1NRaw,
Batch,
StrideA,
StrideB,
StrideB1,
StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideB1,
BatchStrideC,
a_element_op,
b_element_op,
b1_element_op,
return Argument{p_a, p_b, p_b1, p_c, MRaw,
NRaw, KRaw, Gemm1NRaw, Batch, StrideA,
StrideB, StrideB1, StrideC, BatchStrideA, BatchStrideB,
BatchStrideB1, BatchStrideC, a_element_op, b_element_op, b1_element_op,
c_element_op};
}
@@ -181,8 +181,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
@@ -207,7 +207,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB),
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB),
c_block_size * sizeof(FloatCShuffle));
}
@@ -234,7 +235,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
return false;
}
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 && Gemm1N % Gemm1NPerBlock == 0))
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
Gemm1N % Gemm1NPerBlock == 0))
{
return false;
}
@@ -472,8 +474,10 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
const auto a_block_reset_copy_step = make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0);
const auto b_block_reset_copy_step = make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);
const auto a_block_reset_copy_step =
make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0);
const auto b_block_reset_copy_step =
make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);
// gridwise GEMM pipeline
// Only supports LoopScheduler::Default
@@ -1154,11 +1154,11 @@ struct ThreadwiseTensorSliceTransfer_v4
{
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_ref_to_origin_disp_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
i * src_scalar_step_in_vector);
// apply type convert
src_tmp_vector.template AsType<SrcData>()(i) =
src_buf[Number<src_offset>{}];
src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
});
}
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
@@ -1206,7 +1208,8 @@ template <typename SrcData,
typename DimAccessOrder,
index_t DstVectorDim,
index_t DstScalarPerVector,
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), bool>::type = false>
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseTensorSliceTransfer_StaticToStatic
{
static constexpr index_t nDim = SliceLengths::Size();
@@ -1222,7 +1223,10 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
"wrong! Not divisible");
}
template <typename SrcSliceOriginIdx, typename DstSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
template <typename SrcSliceOriginIdx,
typename DstSliceOriginIdx,
typename SrcBuffer,
typename DstBuffer>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const SrcBuffer& src_buf,
@@ -1277,7 +1281,6 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
});
});
}
};
} // namespace ck
@@ -739,13 +739,15 @@ struct XdlopsGemm
"base base_type must be double, float, half, bfloat16, and int8_t!");
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
if constexpr (!TransposeC)
if constexpr(!TransposeC)
{
mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
mfma_instr.template run<MPerXdlops, NPerXdlops>(
p_a_wave[k], p_b_wave[k], p_c_thread);
}
else
{
mfma_instr.template run<MPerXdlops, NPerXdlops>(p_b_wave[k], p_a_wave[k], p_c_thread);
mfma_instr.template run<MPerXdlops, NPerXdlops>(
p_b_wave[k], p_a_wave[k], p_c_thread);
}
});
}
@@ -69,7 +69,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator
arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
v_acc += ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
AccDataType v_c;
@@ -161,9 +161,10 @@ struct GeneratorTensor_Diagonal
T operator()(Ts... Xs) const
{
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
size_t start_dim = dims.size() - NumEffectiveDim;
bool pred = true;
for (size_t i = start_dim + 1; i < dims.size(); i++) {
size_t start_dim = dims.size() - NumEffectiveDim;
bool pred = true;
for(size_t i = start_dim + 1; i < dims.size(); i++)
{
pred &= (dims[start_dim] == dims[i]);
}
return pred ? value : T{0};