"example/vscode:/vscode.git/clone" did not exist on "970fa3e92ec4e67cfbfe1b0428e84870663ab8cd"
Commit 162d0305 authored by root's avatar root
Browse files

add client example

parent d8ab41d5
......@@ -2,6 +2,9 @@ if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf
add_executable(client_gemm_bias_fastgelu_bf16_i8_bf16 gemm_bias_fastgelu_xdl_bf16_i8.cpp)
target_link_libraries(client_gemm_bias_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_multiply_add_fastgelu_xdl_bf16_i8 gemm_multiply_add_fastgelu_xdl_bf16_i8.cpp)
target_link_libraries(client_gemm_multiply_add_fastgelu_xdl_bf16_i8 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_bias_bf16_i8_bf16 gemm_bias_xdl_bf16_i8.cpp)
target_link_libraries(client_gemm_bias_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
......
......@@ -69,7 +69,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
is_same_v<D0DataType, bhalf_t> && is_same_v<D1DataType, bhalf_t> &&
is_same_v<EDataType, bhalf_t>)
{
......@@ -77,7 +77,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
add_device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_multiply_add_fastgelu_mnkpadding_instances(
op_ptrs);
}
}
......
......@@ -37,11 +37,7 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
using DsLayout = ck::Tuple<Row, Row>;
template <
typename DsDType,
typename CElementwiseOp,
GemmSpecialization GemmSpec
>
template <typename DsDType, typename CElementwiseOp, GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
......@@ -60,12 +56,10 @@ using device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_comp_instances = s
// clang-format on
>;
template <
typename DsDType,
template <typename DsDType,
typename CElementwiseOp,
GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched
>
BlockGemmPipelineScheduler BlkGemmPipeSched>
using device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_mem_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment