gemm_dl_int8.cpp 3.67 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

4
#include "common.hpp"
Jianfeng Yan's avatar
Jianfeng Yan committed
5

Chao Liu's avatar
Chao Liu committed
6
#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp"
Jianfeng Yan's avatar
Jianfeng Yan committed
7
8
9
10
11
12
13
14
15
16

// Element types handed to the GEMM instances declared later in this file.
// A, B and C all use 8-bit signed integers; AccDataType is a 32-bit signed
// integer (by its name, the accumulation type -- how it is used is decided by
// the templates that receive it).
using ADataType   = int8_t;
using BDataType   = int8_t;
using CDataType   = int8_t;
using AccDataType = int32_t;

// Layout tags for the three matrices. Col and Row are declared outside this
// file (presumably column-major / row-major tags from common.hpp -- confirm).
using ALayout = Col;
using BLayout = Row;
using CLayout = Row;

17
18
19
// Element-wise operation types for A, B and C. All three use PassThrough,
// which is declared outside this file (presumably an identity functor made
// visible through common.hpp -- confirm there).
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
Jianfeng Yan's avatar
Jianfeng Yan committed
20
21
22
23

// Short name for GemmSpecialization::Default; passed as the GEMM
// specialization template argument of DeviceGemmInstance below.
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// clang-format off
24
25
26
27
28
29
// Device GEMM type used by this example: DeviceGemmDl instantiated with the
// data-type, layout and element-op aliases declared earlier in this file plus
// a fixed set of numeric/sequence tuning arguments. The commented table below
// labels each template argument by column; keep the columns aligned when
// editing (clang-format is switched off around this declaration).
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl
// ######|     AData|     BData|     CData|     AccData| ALayout| BLayout| CLayout|           A|           B|           C|           GEMM| Block|  MPer|  NPer| K0Per| K1|      M1Per|      N1Per|   KPer|  M11N11Thread|  M11N11Thread|     ABlockTransfer|       ABlockTransfer| ABlockTransfer| ABlockTransfer|      ABlockTransfer|     ABlockTransfer|      ABlockTransfer|     BBlockTransfer|       BBlockTransfer| BBlockTransfer| BBlockTransfer|      BBlockTransfer|     BBlockTransfer|      BBlockTransfer|     CThreadTransfer| CThreadTransfer|    CThreadTransfer|
// ######|      Type|      Type|      Type|        Type|        |        |        | Elementwise| Elementwise| Elementwise| Specialization|  Size| Block| Block| Block|   | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths|  ThreadCluster|      SrcAccess|     SrcVectorTensor|    SrcVectorTensor|     DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths|  ThreadCluster|      SrcAccess|     SrcVectorTensor|    SrcVectorTensor|     DstVectorTensor|        SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######|          |          |          |            |        |        |        |   Operation|   Operation|   Operation|               |      |      |      |      |   |           |           |       |              |              |        K0_M0_M1_K1|          K0_M0_M1_K1|   ArrangeOrder|          Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1|        K0_N0_N1_K1|          K0_N0_N1_K1|   ArrangeOrder|          Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1|               Order|                |                   |
// ######|          |          |          |            |        |        |        |            |            |            |               |      |      |      |      |   |           |           |       |              |              |                   |                     |               |               |                    |                   |                    |                   |                     |               |               |                    |                   |                    |                    |                |                   |
         < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout,  AElementOp,  BElementOp,  CElementOp,    GemmDefault,   256,   128,   128,    16,  4,          4,          4,      1,       S<8, 2>,       S<8, 2>,      S<2, 1, 4, 4>,       S<8, 1, 32, 1>,  S<0, 3, 1, 2>,  S<0, 3, 1, 2>,       S<1, 1, 4, 1>,      S<0, 3, 1, 2>,       S<1, 1, 4, 4>,      S<2, 1, 4, 4>,       S<8, 1, 32, 1>,  S<0, 3, 1, 2>,  S<0, 3, 1, 2>,       S<1, 1, 4, 1>,      S<0, 3, 1, 2>,       S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>,               5,                  4>;
Jianfeng Yan's avatar
Jianfeng Yan committed
30
31
32
// clang-format on

// Host-side reference GEMM type, instantiated with the same element types and
// element-wise operations as DeviceGemmInstance. Presumably consumed by
// run_gemm_example.inc to check the device result -- confirm in that file.
using ReferenceGemmInstance = ck::tensor_operation::host::
    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
Jianfeng Yan's avatar
Jianfeng Yan committed
34

35
#include "run_gemm_example.inc"
Jianfeng Yan's avatar
Jianfeng Yan committed
36

37
// Program entry point: forwards the command line to run_gemm_example and turns
// its result into the process exit status (truthy result -> 0, falsy -> 1).
int main(int argc, char* argv[])
{
    const bool failed = !run_gemm_example(argc, argv);
    return failed ? 1 : 0;
}