// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp"

// Implementation follows the paper:
// Kim, Young Jin, Rawn Henry, Raffy Fahim, and Hany Hassan Awadalla. “Who Says Elephants Can’t Run:
// Bringing Large Scale MoE Models into Cloud Scale Production.” arXiv, November 17, 2022.
// https://doi.org/10.48550/arXiv.2211.10017.
// The weight matrix (B) is assumed to have been preprocessed from signed int8 to unsigned uint8.

// The DeviceOp is CDataType = ADataType * Dequant(BDataType) * ScaleDataType
// The HostRef  is CDataType = ADataType * Dequant(QuantDataType) * ScaleDataType

// TODO: The current implementation consumes more VGPRs than expected.

// Element types.
// The device op consumes B as uint8_t (the preprocessed, unsigned form of the weights), while the
// host reference consumes the original signed int8_t weights (QuantDataType).
using ADataType        = ck::half_t;
using QuantDataType    = int8_t;
using BDataType        = uint8_t;
using ScaleDataType    = ck::half_t;
using AccDataType      = float;
using CShuffleDataType = ck::half_t;
using CDataType        = ck::half_t;

// The unsigned device weights are a re-encoding of the signed quantized weights, so both must
// occupy the same number of bytes.
static_assert(sizeof(QuantDataType) == sizeof(BDataType),
              "BDataType must be the same width as QuantDataType");

// Row-major A, column-major B, row-major C.
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

// No fused elementwise operations on A, B or C.
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

// NOTE: despite the name, this selects MNKPadding (padding on the M, N and K dimensions), not the
// plain Default specialization.
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

// clang-format off
// Device kernel instance (fp16 A x uint8 B with fp16 scale, fp32 accumulation).
// Tile arithmetic for this configuration:
//   MWaves = MPerBlock / (MRepeat * MPerWmma) = 64  / (2 * 16) = 2
//   NWaves = NPerBlock / (NRepeat * NPerWmma) = 128 / (4 * 16) = 2
//   MWaves * NWaves = 4 waves; 4 * 32 lanes = 128 = BlockSize (assumes wave32 -- confirm for target)
// The A/B block-transfer thread clusters (4 * 32 * 1) and the C-shuffle cluster (1 * 32 * 1 * 4)
// also each cover 128 threads.
// NOTE(review): the labels on the block-transfer arguments follow the usual CK WMMA GEMM template
// parameter order; verify against device_fpAintB_gemm_wmma.hpp before changing them.
using DeviceGemmInstance = ck::tensor_operation::device::DeviceFpAintBGemm_Wmma_CShuffle
         < ALayout,
           BLayout,
           CLayout,
           ADataType,
           BDataType,
           ScaleDataType,
           CDataType,
           AccDataType,
           CShuffleDataType,
           AElementOp,
           BElementOp,
           CElementOp,
           GemmDefault,
           1,               // Prefetch stage
           128,             // BlockSize
           64,              // MPerBlock
           128,             // NPerBlock
           64,              // KPerBlock
           8,               // K1
           16,              // MPerWmma
           16,              // NPerWmma
           2,               // MRepeat: MPerBlock / (MRepeat * MPerWmma) = MWaves
           4,               // NRepeat: NPerBlock / (NRepeat * NPerWmma) = NWaves
           S<4, 32, 1>,     // ABlockTransfer ThreadClusterLengths (K0, M, K1)
           S<1, 0, 2>,      // ABlockTransfer ThreadClusterArrangeOrder
           S<1, 0, 2>,      // ABlockTransfer SrcAccessOrder
           2,               // ABlockTransfer SrcVectorDim
           8,               // ABlockTransfer SrcScalarPerVector
           8,               // ABlockTransfer DstScalarPerVector_K1
           true,            // ABlockLdsAddExtraM
           S<4, 32, 1>,     // BBlockTransfer ThreadClusterLengths (K0, N, K1)
           S<1, 0, 2>,      // BBlockTransfer ThreadClusterArrangeOrder
           S<1, 0, 2>,      // BBlockTransfer SrcAccessOrder
           2,               // BBlockTransfer SrcVectorDim
           8,               // BBlockTransfer SrcScalarPerVector
           8,               // BBlockTransfer DstScalarPerVector_K1
           true,            // BBlockLdsAddExtraN
           1,               // C shuffle (M Repeat) Per store
           1,               // C shuffle (N Repeat) Per store
           S<1, 32, 1,  4>, // CShuffle cluster lengths (MBlock, MPerBlock, NBlock, NPerBlock)
           8>;              // CShuffle ScalarPerVector (NPerBlock)
// clang-format on

aska-0096's avatar
aska-0096 committed
82
using ReferenceGemmInstance = ck::tensor_operation::host::ReferencefpAintBGemm<ADataType,
aska-0096's avatar
aska-0096 committed
83
                                                                               QuantDataType,
aska-0096's avatar
aska-0096 committed
84
85
86
87
88
89
                                                                               ScaleDataType,
                                                                               CDataType,
                                                                               AccDataType,
                                                                               AElementOp,
                                                                               BElementOp,
                                                                               CElementOp>;
aska-0096's avatar
aska-0096 committed
90
91
92
93

#include "run_gemm_example.inc"

// Entry point: delegates to run_gemm_example (from run_gemm_example.inc).
// A truthy result maps to exit status 0; anything else maps to 1.
int main(int argc, char* argv[])
{
    return run_gemm_example(argc, argv) ? 0 : 1;
}