"vscode:/vscode.git/clone" did not exist on "6586485c4a3bde2490613b4e379f7f72ffc35d78"
Commit edc494df authored by Anthony Chang

clang-format

parent 00331ee4
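
This commit is a pure clang-format pass: every hunk below reflows existing code (line wrapping, brace placement, spacing) without changing behavior. Judging from the hunks, the enforced style uses Allman braces, no space before control-statement parentheses, and one parameter per line once a declaration or call exceeds the column limit; such a pass is typically applied with clang-format -i over the touched files. A minimal .clang-format sketch consistent with these hunks is shown here as an illustration only; the exact column limit is an assumption, and the repository's checked-in .clang-format is authoritative.

    # Hypothetical .clang-format sketch inferred from the hunks below; not the project's actual file
    BasedOnStyle:          LLVM
    ColumnLimit:           100    # assumed; long MakeArgument(...) calls wrap near this width
    IndentWidth:           4
    BreakBeforeBraces:     Allman # '{' on its own line, as in the GeneratorTensor_Diagonal hunk
    SpaceBeforeParens:     Never  # yields 'if constexpr(!TransposeC)' and 'for(size_t i = ...)'
    BinPackParameters:     false  # one function/template parameter per line when wrapping
    AlignAfterOpenBracket: Align  # continuation arguments aligned under the opening bracket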
@@ -119,8 +119,13 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
                                                                                 AElementOp,
                                                                                 B0ElementOp,
                                                                                 CElementOp>;
-using ReferenceGemm1Instance = ck::tensor_operation::host::
-    ReferenceBatchedGemm<ADataType, B1DataType, CDataType, AccDataType, AElementOp, B1ElementOp, CElementOp>;
+using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
+                                                                                B1DataType,
+                                                                                CDataType,
+                                                                                AccDataType,
+                                                                                AElementOp,
+                                                                                B1ElementOp,
+                                                                                CElementOp>;
 
 int main(int argc, char* argv[])
 {
@@ -290,7 +295,8 @@ int main(int argc, char* argv[])
     // do GEMM
     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
-    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
+    auto argument =
+        gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
                           static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
                           static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
                           static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
......
@@ -35,8 +35,8 @@ MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
     return transform_tensor_descriptor(
         TileDesc_K0_MN_K1{},
         make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
-                   make_unmerge_transform(make_tuple(
-                       Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
+                   make_unmerge_transform(
+                       make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
         make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
         make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
 }
@@ -723,8 +723,7 @@ struct BlockwiseGemmXdlops_v2
     using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
 
-    __host__ __device__ BlockwiseGemmXdlops_v2(
-        Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
+    __host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
                                                Tuple4 b_origin = CalculateBThreadOriginDataIndex())
         : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
     {
@@ -738,8 +737,7 @@ struct BlockwiseGemmXdlops_v2
                       "wrong!");
     }
 
-    __host__ __device__ BlockwiseGemmXdlops_v2(
-        const BlockwiseGemmXdlops_v2& other)
+    __host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other)
         : a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin)
     {
     }
......
@@ -38,7 +38,8 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
+        kernel_gemm_gemm_xdl_cshuffle_v1(
+            const FloatAB* __restrict__ p_a_grid,
             const FloatAB* __restrict__ p_b_grid,
             const FloatAB* __restrict__ p_b1_grid,
             FloatC* __restrict__ p_c_grid,
@@ -405,8 +406,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
         {
             const auto B1K0 = KRaw / B1K1;
 
-            const auto b1_grid_desc_bk0_n_bk1 =
-                transform_tensor_descriptor(b1_grid_desc_nraw_kraw,
+            const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                b1_grid_desc_nraw_kraw,
                 make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                            make_pass_through_transform(NRaw)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
@@ -426,8 +427,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}));
 
-            const auto b1_grid_desc_bk0_n_bk1 =
-                transform_tensor_descriptor(b1_grid_desc_n_k,
+            const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                b1_grid_desc_n_k,
                 make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                            make_pass_through_transform(N)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
@@ -435,7 +436,6 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
             return b1_grid_desc_bk0_n_bk1;
         }
     }
-
     static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
@@ -809,26 +809,10 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
                              B1ElementwiseOperation b1_element_op,
                              CElementwiseOperation c_element_op)
     {
-        return Argument{p_a,
-                        p_b,
-                        p_b1,
-                        p_c,
-                        MRaw,
-                        NRaw,
-                        KRaw,
-                        Gemm1NRaw,
-                        Batch,
-                        StrideA,
-                        StrideB,
-                        StrideB1,
-                        StrideC,
-                        BatchStrideA,
-                        BatchStrideB,
-                        BatchStrideB1,
-                        BatchStrideC,
-                        a_element_op,
-                        b_element_op,
-                        b1_element_op,
+        return Argument{p_a, p_b, p_b1, p_c, MRaw,
+                        NRaw, KRaw, Gemm1NRaw, Batch, StrideA,
+                        StrideB, StrideB1, StrideC, BatchStrideA, BatchStrideB,
+                        BatchStrideB1, BatchStrideC, a_element_op, b_element_op, b1_element_op,
                         c_element_op};
     }
......
@@ -207,7 +207,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
         constexpr auto c_block_size =
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
 
-        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB),
+        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
+                             sizeof(FloatAB),
                          c_block_size * sizeof(FloatCShuffle));
     }
@@ -234,7 +235,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
             return false;
         }
 
-        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 && Gemm1N % Gemm1NPerBlock == 0))
+        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
+             Gemm1N % Gemm1NPerBlock == 0))
         {
             return false;
         }
@@ -472,8 +474,10 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
 
-        const auto a_block_reset_copy_step = make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0);
-        const auto b_block_reset_copy_step = make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);
+        const auto a_block_reset_copy_step =
+            make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0);
+        const auto b_block_reset_copy_step =
+            make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);
 
         // gridwise GEMM pipeline
         // Only supports LoopScheduler::Default
......
@@ -1154,11 +1154,11 @@ struct ThreadwiseTensorSliceTransfer_v4
         {
             static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                 constexpr index_t src_offset = src_desc.CalculateOffset(
-                    src_ref_to_origin_disp_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
+                    src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
+                    i * src_scalar_step_in_vector);
 
                 // apply type convert
-                src_tmp_vector.template AsType<SrcData>()(i) =
-                    src_buf[Number<src_offset>{}];
+                src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
             });
         }
 
         // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
@@ -1206,7 +1206,8 @@ template <typename SrcData,
           typename DimAccessOrder,
           index_t DstVectorDim,
           index_t DstScalarPerVector,
-          typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), bool>::type = false>
+          typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+                             bool>::type = false>
 struct ThreadwiseTensorSliceTransfer_StaticToStatic
 {
     static constexpr index_t nDim = SliceLengths::Size();
@@ -1222,7 +1223,10 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
                       "wrong! Not divisible");
     }
 
-    template <typename SrcSliceOriginIdx, typename DstSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
+    template <typename SrcSliceOriginIdx,
+              typename DstSliceOriginIdx,
+              typename SrcBuffer,
+              typename DstBuffer>
     __device__ void Run(const SrcDesc&,
                         const SrcSliceOriginIdx&,
                         const SrcBuffer& src_buf,
@@ -1277,7 +1281,6 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
             });
         });
     }
-
 };
 
 } // namespace ck
@@ -739,13 +739,15 @@ struct XdlopsGemm
                       "base base_type must be double, float, half, bfloat16, and int8_t!");
 
         static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
-            if constexpr (!TransposeC)
+            if constexpr(!TransposeC)
             {
-                mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
+                mfma_instr.template run<MPerXdlops, NPerXdlops>(
+                    p_a_wave[k], p_b_wave[k], p_c_thread);
             }
             else
             {
-                mfma_instr.template run<MPerXdlops, NPerXdlops>(p_b_wave[k], p_a_wave[k], p_c_thread);
+                mfma_instr.template run<MPerXdlops, NPerXdlops>(
+                    p_b_wave[k], p_a_wave[k], p_c_thread);
             }
         });
     }
......
@@ -69,7 +69,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator
                     arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
                     arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
 
-                    v_acc += ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
+                    v_acc +=
+                        ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
                 }
 
                 AccDataType v_c;
......
@@ -163,7 +163,8 @@ struct GeneratorTensor_Diagonal
         std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
         size_t start_dim = dims.size() - NumEffectiveDim;
         bool pred = true;
-        for (size_t i = start_dim + 1; i < dims.size(); i++) {
+        for(size_t i = start_dim + 1; i < dims.size(); i++)
+        {
             pred &= (dims[start_dim] == dims[i]);
         }
         return pred ? value : T{0};
......