// fp16int8_gemm_wmma.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp"

// Implementation follows the paper:
// Kim, Young Jin, Rawn Henry, Raffy Fahim, and Hany Hassan Awadalla. “Who Says Elephants Can’t Run:
// Bringing Large Scale MoE Models into Cloud Scale Production.” arXiv, November 17, 2022.
// https://doi.org/10.48550/arXiv.2211.10017. The weights (matrix B) are assumed to have been
// preprocessed into an unsigned representation.

// The DeviceOp computes CDataType = ADataType * Dequant(BDataType) * ScaleDataType
// The HostRef  computes CDataType = ADataType * Dequant(QuantDataType) * ScaleDataType

// Element types of the mixed-precision (fp16 x 8-bit weight) GEMM.
using ADataType        = ck::half_t; // element type of matrix A
// Signed 8-bit quantized weight type; used as the B element type of ReferenceGemmInstance.
using QuantDataType    = int8_t;
// Unsigned 8-bit weight type; used as the B element type of DeviceGemmInstance
// (matrix B after preprocessing to unsigned).
using BDataType        = uint8_t;
using ScaleDataType    = ck::half_t; // scale factor in C = A * Dequant(B) * Scale
using AccDataType      = float;      // accumulator element type
// NOTE(review): presumably the staging type for the C-shuffle epilogue — confirm in the device op.
using CShuffleDataType = ck::half_t;
using CDataType        = ck::half_t; // element type of the output matrix C

using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

// Device-side fp16 x uint8 GEMM instance (WMMA with C-shuffle epilogue).
// Tile configuration: 128 threads per block, 128 x 128 x 64 block tile, 16 x 16 WMMA tiles.
// With MRepeat = NRepeat = 4 the block is split into
//   MWave = MPerBlock / (MRepeat * MPerWmma) = 128 / (4 * 16) = 2 waves along M, and
//   NWave = NPerBlock / (NRepeat * NPerWmma) = 128 / (4 * 16) = 2 waves along N.
// NOTE(review): 2 * 2 waves * 32 lanes = 128 = BlockSize assumes wave32 — confirm for the target.
// NOTE(review): the names on the block-transfer / C-shuffle parameters below follow the
// DeviceGemmWmma_CShuffle convention; verify against device_fpAintB_gemm_wmma.hpp.
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceFpAintBGemm_Wmma_CShuffle
         < ALayout,
           BLayout,
           CLayout,
           ADataType,
           BDataType,
           ScaleDataType,
           CDataType,
           AccDataType,
           CShuffleDataType,
           AElementOp,
           BElementOp,
           CElementOp,
           GemmDefault,
           1,               // Prefetch stage
           128,             // BlockSize
           128,             // MPerBlock
           128,             // NPerBlock
           64,              // KPerBlock
           8,               // K1
           16,              // MPerWmma
           16,              // NPerWmma
           4,               // M-Repeat // MPerBlock / (M-Repeat * MPerWmma) = M-Wave
           4,               // N-Repeat // NPerBlock / (N-Repeat * NPerWmma) = N-Wave
           S<4, 32, 1>,     // A block transfer: thread cluster lengths (K0, M, K1); 4*32*1 = 128 threads
           S<1, 0, 2>,      // A block transfer: thread cluster arrange order
           S<1, 0, 2>,      // A block transfer: source access order
           2,               // A block transfer: source vector dim
           8,               // A block transfer: source scalars per vector
           8,               // A block transfer: destination scalars per vector (K1)
           true,            // A block: LDS add extra M
           S<4, 32, 1>,     // B block transfer: thread cluster lengths (K0, N, K1); 4*32*1 = 128 threads
           S<1, 0, 2>,      // B block transfer: thread cluster arrange order
           S<1, 0, 2>,      // B block transfer: source access order
           2,               // B block transfer: source vector dim
           8,               // B block transfer: source scalars per vector
           8,               // B block transfer: destination scalars per vector (K1)
           true,            // B block: LDS add extra N
           1,               // C shuffle (M Repeat) Per store
           1,               // C shuffle (N Repeat) Per store
           S<1, 32, 1,  4>, // C shuffle transfer cluster lengths (MBlock, MPerBlock, NBlock, NPerBlock); 1*32*1*4 = 128
           8>;              // C shuffle transfer scalars per vector (NPerBlock)
// clang-format on

aska-0096's avatar
aska-0096 committed
80
using ReferenceGemmInstance = ck::tensor_operation::host::ReferencefpAintBGemm<ADataType,
aska-0096's avatar
aska-0096 committed
81
                                                                               QuantDataType,
aska-0096's avatar
aska-0096 committed
82
83
84
85
86
87
                                                                               ScaleDataType,
                                                                               CDataType,
                                                                               AccDataType,
                                                                               AElementOp,
                                                                               BElementOp,
                                                                               CElementOp>;
aska-0096's avatar
aska-0096 committed
88
89
90
91

#include "run_gemm_example.inc"

// Program entry point: runs the example through run_gemm_example (provided by the
// run_gemm_example.inc include above). A truthy result maps to exit code 0 and anything
// else maps to exit code 1.
int main(int argc, char* argv[])
{
    return run_gemm_example(argc, argv) ? 0 : 1;
}