Commit 2af456af authored by coderfeli's avatar coderfeli
Browse files

fix

parent e493ab00
......@@ -69,24 +69,21 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MultiplyMultiply;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
// index_t CShuffleMXdlPerWavePerShuffle,
// index_t CShuffleNXdlPerWavePerShuffle,
// typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
// typename CDEShuffleBlockTransferScalarPerVectors,
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// kernel 1: 256->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
<Row, Col, DsLayout, ELayout,
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
128, 128, 128,
16, 16,
16, 16,
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
// clang-format on
int main(int argc, char* argv[])
......@@ -235,7 +232,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50});
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 20, 50, true, 50});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
......
......@@ -21,10 +21,9 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_ins
PassThrough,
MultiplyMultiply>>>& instances)
{
// add_device_operation_instances(
// instances,
// device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part1<GemmKPadding>{});
(void)instances;
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part1<GemmKPadding>{});
}
} // namespace instance
......
......@@ -21,11 +21,9 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_ins
PassThrough,
MultiplyMultiply>>>& instances)
{
// add_device_operation_instances(
// instances,
// device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part2<GemmKPadding>{});
(void)instances;
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part2<GemmKPadding>{});
}
} // namespace instance
......
......@@ -21,12 +21,10 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kp
PassThrough,
MultiplyMultiply>>>& instances)
{
// add_device_operation_instances(
// instances,
// device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_instances_part1<
// GemmKPadding>{});
(void)instances;
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_instances_part1<
GemmKPadding>{});
}
} // namespace instance
......
......@@ -21,12 +21,10 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kp
PassThrough,
MultiplyMultiply>>>& instances)
{
// add_device_operation_instances(
// instances,
// device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_instances_part2<
// GemmKPadding>{});
(void)instances;
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_instances_part2<
GemmKPadding>{});
}
} // namespace instance
......
......@@ -21,12 +21,10 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_i
PassThrough,
MultiplyMultiply>>>& instances)
{
// add_device_operation_instances(
// instances,
// device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances<Intrawave,
// GemmKPadding>{});
(void)instances;
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances<Intrawave,
GemmKPadding>{});
}
} // namespace instance
......
......@@ -21,12 +21,10 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_i
PassThrough,
MultiplyMultiply>>>& instances)
{
// add_device_operation_instances(
// instances,
// device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances<Interwave,
// GemmKPadding>{});
(void)instances;
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances<Interwave,
GemmKPadding>{});
}
} // namespace instance
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment