Commit 162d0305 authored by root
Browse files

add client example

parent d8ab41d5
......@@ -2,6 +2,9 @@ if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf
add_executable(client_gemm_bias_fastgelu_bf16_i8_bf16 gemm_bias_fastgelu_xdl_bf16_i8.cpp)
target_link_libraries(client_gemm_bias_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_multiply_add_fastgelu_xdl_bf16_i8 gemm_multiply_add_fastgelu_xdl_bf16_i8.cpp)
target_link_libraries(client_gemm_multiply_add_fastgelu_xdl_bf16_i8 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_bias_bf16_i8_bf16 gemm_bias_xdl_bf16_i8.cpp)
target_link_libraries(client_gemm_bias_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
......
......@@ -69,7 +69,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
is_same_v<D0DataType, bhalf_t> && is_same_v<D1DataType, bhalf_t> &&
is_same_v<EDataType, bhalf_t>)
{
......@@ -77,7 +77,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
add_device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_multiply_add_fastgelu_mnkpadding_instances(
op_ptrs);
}
}
......
......@@ -14,9 +14,9 @@ namespace tensor_operation {
namespace device {
namespace instance {
using I8 = int8_t;
using I8 = int8_t;
using BF16 = bhalf_t;
using F32 = float;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
......@@ -24,7 +24,7 @@ using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
using PassThrough = element_wise::PassThrough;
using MultiplyAddFastGelu = element_wise::MultiplyAddFastGelu;
static constexpr auto GemmDefault = GemmSpecialization::Default;
......@@ -37,11 +37,7 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
using DsLayout = ck::Tuple<Row, Row>;
template <
typename DsDType,
typename CElementwiseOp,
GemmSpecialization GemmSpec
>
template <typename DsDType, typename CElementwiseOp, GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
......@@ -60,12 +56,10 @@ using device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_comp_instances = s
// clang-format on
>;
template <
typename DsDType,
typename CElementwiseOp,
GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched
>
template <typename DsDType,
typename CElementwiseOp,
GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched>
using device_gemm_xdl_universal_multi_d_bf16_i8_bf16_mk_kn_mn_mem_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment