Commit c7913947 authored by Jakub Piasecki's avatar Jakub Piasecki
Browse files

added gemm xdl fp16_int8_fp16

parent 886d9eeb
......@@ -71,3 +71,6 @@ endforeach()
# Mixed-input XDL GEMM examples. Each executable is built from the listed source
# file and then passed, together with example_gemm_xdl, to
# add_example_dependencies (presumably grouping it under that target — confirm
# against the macro definition).
add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
# Same registration pattern for the fp16/int8 example.
add_example_executable(example_gemm_xdl_fp16_int8 gemm_xdl_fp16_int8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_int8)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
// Element types: A, C, the C-shuffle stage and the compute type are ck::half_t,
// B is int8_t, and AccDataType is float.
using ADataType        = ck::half_t;
using BDataType        = int8_t;
using CDataType        = ck::half_t;
using CShuffleDataType = ck::half_t;
using ComputeType      = ck::half_t;
using AccDataType      = float;

// Layout tags for the three matrices (B uses the Col tag, A and C use Row).
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

// Elementwise operations applied to A, B and C; all are PassThrough here.
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

// Compile-time kernel configuration consumed by the device instance below.
static constexpr ck::tensor_operation::device::GemmSpecialization GemmDefault =
    ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto LoopSched              = ck::make_default_loop_scheduler();
static constexpr ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1;
// Device GEMM type used by this example: DeviceGemm_Xdl_CShuffle instantiated with
// the layout, data-type and elementwise-op aliases declared above. The numeric and
// S<...> arguments are positional; the column headers below label each position.
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| ComputeType|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| |
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopSched, PipelineVer, ComputeType>;
// clang-format on
// Reference GEMM type from the ck::tensor_operation::host namespace, instantiated
// with the same element types and elementwise operations as the device instance.
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                        BDataType,
                                                                        CDataType,
                                                                        AccDataType,
                                                                        AElementOp,
                                                                        BElementOp,
                                                                        CElementOp>;
#include "run_gemm_example.inc"
// Entry point: forwards the command line to run_gemm_example and maps its
// truthy result to exit code 0, anything else to exit code 1.
int main(int argc, char* argv[])
{
    const bool passed = run_gemm_example(argc, argv);
    return passed ? 0 : 1;
}
......@@ -367,6 +367,17 @@ void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(
DeviceGemm<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_INT8)
// Instance-registration entry points for the F16 (A) x I8 (B) -> F16 (C) XDL
// C-shuffle GEMM. Each function receives the vector of DeviceGemm pointers
// (`instances`) for one layout combination; only declared here.

// mk_kn_mn: A, B and C all use the Row layout tag.
void add_device_gemm_xdl_c_shuffle_f16_int8_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Row, Row, F16, I8, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);

// mk_nk_mn: B uses the Col layout tag. This must be <Row, Col, Row> so that the
// signature matches the factory call site, which invokes this function when
// BLayout is Col, and agrees with the other *_mk_nk_mn declarations in this file.
void add_device_gemm_xdl_c_shuffle_f16_int8_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceGemm<Row, Col, Row, F16, I8, F16, PassThrough, PassThrough, PassThrough>>>&
        instances);
#endif
template <typename ALayout,
typename BLayout,
......@@ -612,6 +623,21 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(op_ptrs);
}
}
#endif
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_INT8)
        // half_t A, int8_t B, half_t C: pick the instance set matching the layouts.
        else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, int8_t> &&
                          is_same_v<CDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<CLayout, Row>)
            {
                add_device_gemm_xdl_c_shuffle_f16_int8_f16_mk_kn_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                              is_same_v<CLayout, Row>)
            {
                add_device_gemm_xdl_c_shuffle_f16_int8_f16_mk_nk_mn_instances(op_ptrs);
            }
        } // closes the data-type branch; it was left open, unbalancing braces when this guard is enabled
#endif
return op_ptrs;
}
......
......@@ -46,6 +46,8 @@ list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_int8_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_int8_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp
......
......@@ -24,6 +24,7 @@ enum struct GemmDataType
BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3
F8_F8_F8, // 4
F16_INT8_F16 // 5
};
#define OP_NAME "gemm"
......@@ -32,7 +33,7 @@ enum struct GemmDataType
static void print_helper_msg()
{
std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: fp8)\n"
<< "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: fp8; 5: fp16 & int8)\n"
<< "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"
<< " 1: A[m, k] * B[n, k] = C[m, n];\n"
<< " 2: A[k, m] * B[k, n] = C[m, n];\n"
......@@ -175,6 +176,16 @@ int profile_gemm(int argc, char* argv[])
return profile(Col{}, Col{}, Row{}, F16{}, F16{}, F32{}, F16{});
}
#endif
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_INT8)
// GemmDataType::F16_INT8_F16 (listed as data type 5, "fp16 & int8", in the help
// text). These branches are compiled only when both CK_ENABLE_FP16 and
// CK_ENABLE_INT8 are defined. They continue the surrounding else-if chain, so
// their position relative to the neighbouring guards matters.
// NOTE(review): the profile() arguments appear to be (A layout, B layout,
// C layout, A type, B type, accumulator type, C type) — confirm against the
// profile lambda, which is not visible here.
// Layout MK_KN_MN: Row tags for A, B and C.
else if(data_type == GemmDataType::F16_INT8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(Row{}, Row{}, Row{}, F16{}, INT8{}, F32{}, F16{});
}
// Layout MK_NK_MN: B is passed with the Col tag.
else if(data_type == GemmDataType::F16_INT8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(Row{}, Col{}, Row{}, F16{}, INT8{}, F32{}, F16{});
}
#endif
#ifdef CK_ENABLE_BF16
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment