Commit 162d0305 authored by root's avatar root
Browse files

add client example

parent d8ab41d5
...@@ -2,6 +2,9 @@ if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf ...@@ -2,6 +2,9 @@ if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf
add_executable(client_gemm_bias_fastgelu_bf16_i8_bf16 gemm_bias_fastgelu_xdl_bf16_i8.cpp) add_executable(client_gemm_bias_fastgelu_bf16_i8_bf16 gemm_bias_fastgelu_xdl_bf16_i8.cpp)
target_link_libraries(client_gemm_bias_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations) target_link_libraries(client_gemm_bias_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_multiply_add_fastgelu_xdl_bf16_i8 gemm_multiply_add_fastgelu_xdl_bf16_i8.cpp)
target_link_libraries(client_gemm_multiply_add_fastgelu_xdl_bf16_i8 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_bias_bf16_i8_bf16 gemm_bias_xdl_bf16_i8.cpp) add_executable(client_gemm_bias_bf16_i8_bf16 gemm_bias_xdl_bf16_i8.cpp)
target_link_libraries(client_gemm_bias_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations) target_link_libraries(client_gemm_bias_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
......
...@@ -69,7 +69,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu ...@@ -69,7 +69,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> && if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
is_same_v<D0DataType, bhalf_t> && is_same_v<D1DataType, bhalf_t> && is_same_v<D0DataType, bhalf_t> && is_same_v<D1DataType, bhalf_t> &&
is_same_v<EDataType, bhalf_t>) is_same_v<EDataType, bhalf_t>)
{ {
...@@ -77,7 +77,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu ...@@ -77,7 +77,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> && is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
{ {
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( add_device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_multiply_add_fastgelu_mnkpadding_instances(
op_ptrs); op_ptrs);
} }
} }
......
...@@ -14,9 +14,9 @@ namespace tensor_operation { ...@@ -14,9 +14,9 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
using I8 = int8_t; using I8 = int8_t;
using BF16 = bhalf_t; using BF16 = bhalf_t;
using F32 = float; using F32 = float;
using Row = tensor_layout::gemm::RowMajor; using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor; using Col = tensor_layout::gemm::ColumnMajor;
...@@ -24,7 +24,7 @@ using Col = tensor_layout::gemm::ColumnMajor; ...@@ -24,7 +24,7 @@ using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is> template <index_t... Is>
using S = Sequence<Is...>; using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough; using PassThrough = element_wise::PassThrough;
using MultiplyAddFastGelu = element_wise::MultiplyAddFastGelu; using MultiplyAddFastGelu = element_wise::MultiplyAddFastGelu;
static constexpr auto GemmDefault = GemmSpecialization::Default; static constexpr auto GemmDefault = GemmSpecialization::Default;
...@@ -37,11 +37,7 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; ...@@ -37,11 +37,7 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
using DsLayout = ck::Tuple<Row, Row>; using DsLayout = ck::Tuple<Row, Row>;
template < template <typename DsDType, typename CElementwiseOp, GemmSpecialization GemmSpec>
typename DsDType,
typename CElementwiseOp,
GemmSpecialization GemmSpec
>
using device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple< using device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple<
// clang-format off // clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| ALayout| BLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
...@@ -60,12 +56,10 @@ using device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_comp_instances = s ...@@ -60,12 +56,10 @@ using device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_comp_instances = s
// clang-format on // clang-format on
>; >;
template < template <typename DsDType,
typename DsDType, typename CElementwiseOp,
typename CElementwiseOp, GemmSpecialization GemmSpec,
GemmSpecialization GemmSpec, BlockGemmPipelineScheduler BlkGemmPipeSched>
BlockGemmPipelineScheduler BlkGemmPipeSched
>
using device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_mem_instances = std::tuple< using device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_mem_instances = std::tuple<
// clang-format off // clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| ALayout| BLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment