Commit edc494df authored by Anthony Chang

clang-format

parent 00331ee4
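Note: this is a pure formatting commit. Every hunk below re-wraps existing lines to the project's clang-format style with no intended change in behavior. A commit like this is typically produced by running clang-format in place (e.g. clang-format -i) over the touched files; the exact invocation or script is not recorded here, and the style presumably comes from the repository's .clang-format.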
@@ -119,8 +119,13 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
                                                                  AElementOp,
                                                                  B0ElementOp,
                                                                  CElementOp>;
-using ReferenceGemm1Instance = ck::tensor_operation::host::
-    ReferenceBatchedGemm<ADataType, B1DataType, CDataType, AccDataType, AElementOp, B1ElementOp, CElementOp>;
+using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
+                                                                                B1DataType,
+                                                                                CDataType,
+                                                                                AccDataType,
+                                                                                AElementOp,
+                                                                                B1ElementOp,
+                                                                                CElementOp>;
 
 int main(int argc, char* argv[])
 {
@@ -290,7 +295,8 @@ int main(int argc, char* argv[])
     // do GEMM
     auto gemm = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
-    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
+    auto argument =
+        gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
                           static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
                           static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
                           static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
......
@@ -35,8 +35,8 @@ MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
     return transform_tensor_descriptor(
         TileDesc_K0_MN_K1{},
         make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
-                   make_unmerge_transform(make_tuple(
-                       Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
+                   make_unmerge_transform(
+                       make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
         make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
         make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
 }
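Side note on the hunk above (the formatting change itself is behavior-neutral): the merge transform collapses the (K0, K1) pair into one flat K index and the unmerge transform splits the MN index into (MNXdlPerWave, MNWaves, MNPerXdl) coordinates. Below is a standalone sketch of that index arithmetic only, assuming the usual last-dimension-fastest convention; the helper names and example sizes are made up for illustration and are not CK's descriptor machinery.

// Standalone sketch of merge/unmerge index arithmetic (illustration only).
#include <cassert>
#include <cstdint>

// merge (k0, k1) with extents (K0, K1) into a single flat K index
std::int64_t merge_k0_k1(std::int64_t k0, std::int64_t k1, std::int64_t K1)
{
    return k0 * K1 + k1;
}

struct MnCoord
{
    std::int64_t xdl_per_wave, wave, per_xdl;
};

// unmerge a flat MN index into (MNXdlPerWave, MNWaves, MNPerXdl) coordinates
MnCoord unmerge_mn(std::int64_t mn, std::int64_t mn_waves, std::int64_t mn_per_xdl)
{
    MnCoord c{};
    c.per_xdl      = mn % mn_per_xdl;
    c.wave         = (mn / mn_per_xdl) % mn_waves;
    c.xdl_per_wave = mn / (mn_per_xdl * mn_waves);
    return c;
}

int main()
{
    // example extents: K0 = 4, K1 = 8 -> element (k0 = 2, k1 = 3) sits at flat index 19
    assert(merge_k0_k1(2, 3, 8) == 19);

    // example extents: MNWaves = 2, MNPerXdl = 32 -> flat index 70 is (1, 0, 6)
    MnCoord c = unmerge_mn(70, 2, 32);
    assert(c.xdl_per_wave == 1 && c.wave == 0 && c.per_xdl == 6);
    return 0;
}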
@@ -723,8 +723,7 @@ struct BlockwiseGemmXdlops_v2
     using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
 
-    __host__ __device__ BlockwiseGemmXdlops_v2(
-        Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
+    __host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
                                                Tuple4 b_origin = CalculateBThreadOriginDataIndex())
         : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
     {
@@ -738,8 +737,7 @@ struct BlockwiseGemmXdlops_v2
                       "wrong!");
     }
 
-    __host__ __device__ BlockwiseGemmXdlops_v2(
-        const BlockwiseGemmXdlops_v2& other)
+    __host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other)
         : a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin)
     {
     }
......
@@ -38,7 +38,8 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
+        kernel_gemm_gemm_xdl_cshuffle_v1(
+            const FloatAB* __restrict__ p_a_grid,
             const FloatAB* __restrict__ p_b_grid,
             const FloatAB* __restrict__ p_b1_grid,
             FloatC* __restrict__ p_c_grid,
@@ -405,8 +406,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
         {
             const auto B1K0 = KRaw / B1K1;
 
-            const auto b1_grid_desc_bk0_n_bk1 =
-                transform_tensor_descriptor(b1_grid_desc_nraw_kraw,
+            const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                b1_grid_desc_nraw_kraw,
                 make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                            make_pass_through_transform(NRaw)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
@@ -426,8 +427,8 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}));
 
-            const auto b1_grid_desc_bk0_n_bk1 =
-                transform_tensor_descriptor(b1_grid_desc_n_k,
+            const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                b1_grid_desc_n_k,
                 make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                            make_pass_through_transform(N)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
@@ -435,7 +436,6 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
             return b1_grid_desc_bk0_n_bk1;
         }
     }
-
 
     static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
@@ -809,26 +809,10 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
                              B1ElementwiseOperation b1_element_op,
                              CElementwiseOperation c_element_op)
     {
-        return Argument{p_a,
-                        p_b,
-                        p_b1,
-                        p_c,
-                        MRaw,
-                        NRaw,
-                        KRaw,
-                        Gemm1NRaw,
-                        Batch,
-                        StrideA,
-                        StrideB,
-                        StrideB1,
-                        StrideC,
-                        BatchStrideA,
-                        BatchStrideB,
-                        BatchStrideB1,
-                        BatchStrideC,
-                        a_element_op,
-                        b_element_op,
-                        b1_element_op,
+        return Argument{p_a, p_b, p_b1, p_c, MRaw,
+                        NRaw, KRaw, Gemm1NRaw, Batch, StrideA,
+                        StrideB, StrideB1, StrideC, BatchStrideA, BatchStrideB,
+                        BatchStrideB1, BatchStrideC, a_element_op, b_element_op, b1_element_op,
                         c_element_op};
     }
......
@@ -207,7 +207,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
         constexpr auto c_block_size =
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
 
-        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB),
+        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
+                             sizeof(FloatAB),
                          c_block_size * sizeof(FloatCShuffle));
     }
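Aside on the re-wrapped expression above: taking the max of the A+B tile footprint and the C-shuffle tile footprint suggests the two phases reuse the same shared-memory block. A rough standalone sketch of that sizing rule, with made-up tile extents and element sizes (not CK's types or template parameters):

// Illustration only: shared memory is sized for the larger of the two phases.
#include <algorithm>
#include <cstdio>

int main()
{
    // assumed per-block tile extents, for illustration only
    const unsigned long a_block_elems = 128UL * 32UL;  // MPerBlock * KPerBlock
    const unsigned long b_block_elems = 128UL * 32UL;  // NPerBlock * KPerBlock
    const unsigned long c_block_elems = 128UL * 128UL; // MPerBlock * NPerBlock (shuffle tile)

    const unsigned long ab_bytes = (a_block_elems + b_block_elems) * 2UL; // fp16 inputs
    const unsigned long c_bytes  = c_block_elems * 4UL;                   // fp32 shuffle buffer

    std::printf("shared memory needed: %lu bytes\n", std::max(ab_bytes, c_bytes));
    return 0;
}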
@@ -234,7 +235,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
             return false;
         }
 
-        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 && Gemm1N % Gemm1NPerBlock == 0))
+        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
+             Gemm1N % Gemm1NPerBlock == 0))
         {
             return false;
         }
@@ -472,8 +474,10 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
 
-        const auto a_block_reset_copy_step = make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0);
-        const auto b_block_reset_copy_step = make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);
+        const auto a_block_reset_copy_step =
+            make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0);
+        const auto b_block_reset_copy_step =
+            make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);
 
         // gridwise GEMM pipeline
         // Only supports LoopScheduler::Default
......
@@ -1154,11 +1154,11 @@ struct ThreadwiseTensorSliceTransfer_v4
         {
             static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                 constexpr index_t src_offset = src_desc.CalculateOffset(
-                    src_ref_to_origin_disp_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
+                    src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
+                    i * src_scalar_step_in_vector);
 
                 // apply type convert
-                src_tmp_vector.template AsType<SrcData>()(i) =
-                    src_buf[Number<src_offset>{}];
+                src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
             });
         }
 
         // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
@@ -1206,7 +1206,8 @@ template <typename SrcData,
           typename DimAccessOrder,
           index_t DstVectorDim,
           index_t DstScalarPerVector,
-          typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), bool>::type = false>
+          typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+                             bool>::type = false>
 struct ThreadwiseTensorSliceTransfer_StaticToStatic
 {
     static constexpr index_t nDim = SliceLengths::Size();
@@ -1222,7 +1223,10 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
                       "wrong! Not divisible");
     }
 
-    template <typename SrcSliceOriginIdx, typename DstSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
+    template <typename SrcSliceOriginIdx,
+              typename DstSliceOriginIdx,
+              typename SrcBuffer,
+              typename DstBuffer>
     __device__ void Run(const SrcDesc&,
                         const SrcSliceOriginIdx&,
                         const SrcBuffer& src_buf,
@@ -1277,7 +1281,6 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
             });
         });
     }
-
 };
 } // namespace ck
@@ -739,13 +739,15 @@ struct XdlopsGemm
                       "base base_type must be double, float, half, bfloat16, and int8_t!");
 
         static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
-            if constexpr (!TransposeC)
+            if constexpr(!TransposeC)
             {
-                mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
+                mfma_instr.template run<MPerXdlops, NPerXdlops>(
+                    p_a_wave[k], p_b_wave[k], p_c_thread);
             }
             else
             {
-                mfma_instr.template run<MPerXdlops, NPerXdlops>(p_b_wave[k], p_a_wave[k], p_c_thread);
+                mfma_instr.template run<MPerXdlops, NPerXdlops>(
+                    p_b_wave[k], p_a_wave[k], p_c_thread);
             }
         });
     }
......
@@ -69,7 +69,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator
                         arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
                         arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
 
-                        v_acc += ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
+                        v_acc +=
+                            ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
                     }
 
                     AccDataType v_c;
......
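Aside on the hunk above: the reference batched GEMM accumulates in AccDataType so that low-precision A/B operands do not lose precision inside the K loop, and only the final sum is narrowed. A minimal host-side sketch of the same pattern (assumed row-major layouts and made-up helper name; not the CK reference operator):

// Illustration only: widen operands before multiply-accumulate, narrow once at the end.
#include <cstddef>
#include <vector>

using AccDataType = float;

template <typename ADataType, typename BDataType, typename CDataType>
void naive_batched_gemm(const std::vector<ADataType>& a, // [G, M, K], row-major
                        const std::vector<BDataType>& b, // [G, K, N], row-major
                        std::vector<CDataType>& c,       // [G, M, N], row-major
                        std::size_t G,
                        std::size_t M,
                        std::size_t N,
                        std::size_t K)
{
    for(std::size_t g = 0; g < G; ++g)
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                AccDataType v_acc = 0;
                for(std::size_t k = 0; k < K; ++k)
                {
                    // widen each operand, mirroring v_acc += type_convert<AccDataType>(v_a) * ...
                    v_acc += static_cast<AccDataType>(a[(g * M + m) * K + k]) *
                             static_cast<AccDataType>(b[(g * K + k) * N + n]);
                }
                c[(g * M + m) * N + n] = static_cast<CDataType>(v_acc);
            }
}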
@@ -163,7 +163,8 @@ struct GeneratorTensor_Diagonal
         std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
         size_t start_dim = dims.size() - NumEffectiveDim;
         bool pred = true;
-        for (size_t i = start_dim + 1; i < dims.size(); i++) {
+        for(size_t i = start_dim + 1; i < dims.size(); i++)
+        {
             pred &= (dims[start_dim] == dims[i]);
         }
         return pred ? value : T{0};
......
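For orientation (the re-bracing above does not change it): this loop appears to implement the diagonal predicate of GeneratorTensor_Diagonal, assuming the Xs... arguments are the element indices of the tensor. A standalone sketch of that logic, with a made-up helper name:

// Illustration only: emit `value` on the diagonal of the trailing effective dims, else 0.
#include <array>
#include <cassert>
#include <cstddef>

template <typename T, std::size_t N>
T diagonal_value(const std::array<std::size_t, N>& idx, std::size_t num_effective_dim, T value)
{
    const std::size_t start_dim = N - num_effective_dim;

    // on the diagonal iff all effective (trailing) indices are equal
    bool pred = true;
    for(std::size_t i = start_dim + 1; i < N; i++)
    {
        pred &= (idx[start_dim] == idx[i]);
    }
    return pred ? value : T{0};
}

int main()
{
    // e.g. a batched identity: index (g, m, n) with 2 effective dims -> 1 iff m == n
    const std::array<std::size_t, 3> on_diag{7, 3, 3};
    const std::array<std::size_t, 3> off_diag{7, 3, 4};

    assert(diagonal_value(on_diag, 2, 1.0f) == 1.0f);
    assert(diagonal_value(off_diag, 2, 1.0f) == 0.0f);
    return 0;
}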