"...amoeba/platforms/cuda-old/src/kernels/amoebaCudaKernels.h" did not exist on "0dd63d02cb12e0c2792d84de9ff6ea5df4493ff0"
asm_flatmm_a8w8_blockscale.cpp 2.73 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
// SPDX-License-Identifier: MIT
 
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "aiter_hip_common.h"
#include "hip_float8.h"

// Argument block handed to the assembly kernel launch as a single buffer
// (its address and sizeof() are what the host launcher passes along).
// The struct is packed, so members are laid out back to back in declaration
// order with no padding between them.
// NOTE(review): the layout presumably has to match what the pre-compiled
// kernel expects -- confirm before reordering, adding or resizing members.
struct __attribute__((packed)) KernelArgs
{
    const void* a_ptr;  // A operand, [m, k]
    const void* b_ptr;  // B operand, [n, k] -> [n/128, k*128]
    const void* c_ptr;  // set to nullptr by the host launcher in this file
    const void* sa_ptr; // scales for A, [k/128, m]
    const void* sb_ptr; // scales for B, [k/128, n/128]
    void* d_ptr;        // set to nullptr by the host launcher in this file
    void* d_f16_ptr;    // fp16 output, [m, n]
    // Debug pointers: not explicitly assigned by the host launcher in this file.
    // NOTE(review): presumably debug dump buffers for the kernel -- confirm.
    void* dbg_int_ptr;
    void* dbg_fp8_ptr;
    void* dbg_f16_ptr;
    void* dbg_fp32_ptr;

    int hidden_size;       // K
    int intermediate_size; // N
    int num_tokens;        // M

    // Not explicitly assigned by the host launcher in this file.
    // NOTE(review): names suggest MoE-style parameters -- confirm whether the
    // kernel reads them.
    int num_experts;
    int topk;
    int stride_token;
};

using namespace hip_fp8_impl;
// Host-side launcher for the pre-compiled assembly kernel
// "flatmm_uk_gfx9_f16f8_128x256x128_1x4x1_16x16x32".
//
// Parameters (shapes as documented by the original author):
//   XQ      - [M, K] quantized A operand
//   WQ      - [N, K] -> [N/128, K*128] quantized B operand
//   x_scale - [K/128, M] scales for XQ
//   w_scale - [K/128, N/128] scales for WQ
//   out     - [M, N] Half output tensor, written by the kernel
//
// Returns `out`. The launch is enqueued on the current stream of XQ's device;
// this function does not synchronize.
//
// Raises (via TORCH_CHECK) when `out` is not Half, when XQ/out are not 2-D,
// when out.size(0) != XQ.size(0), or when N / K are not multiples of the
// 256 / 128 tile sizes.
//
// NOTE(review): only raw data pointers are handed to the kernel (no strides),
// so the tensors are presumably expected to be contiguous -- confirm with callers.
torch::Tensor flatmm_a8w8_blockscale_asm(
    torch::Tensor &XQ,      // [M, K]
    torch::Tensor &WQ,      // [N, K] -> [N/128, K*128]
    torch::Tensor &x_scale, // [K/128, M]
    torch::Tensor &w_scale, // [K/128, N/128]
    torch::Tensor &out      // Out:[M, N] fp16
)
{
    constexpr int TileM = 128;
    constexpr int TileN = 256;
    constexpr int TileK = 128;

    TORCH_CHECK(XQ.dim() == 2 && out.dim() == 2,
                "flatmm a8w8 blockscale asm expects 2-D XQ and out tensors!");

    const int m = static_cast<int>(XQ.size(0));
    const int n = static_cast<int>(out.size(1));
    const int k = static_cast<int>(XQ.size(1));

    TORCH_CHECK(out.dtype() == torch::ScalarType::Half,
                "flatmm a8w8 blockscale asm only support Half output now!");
    TORCH_CHECK(n % TileN == 0 && k % TileK == 0,
                "flatmm a8w8 blockscale asm only support 128x256x128 tile now!");
    // The grid's y extent is derived from XQ's row count, so `out` must have
    // the same number of rows for the kernel's writes to stay in bounds.
    TORCH_CHECK(out.size(0) == XQ.size(0),
                "flatmm a8w8 blockscale asm expects out.size(0) == XQ.size(0)!");

    // Nothing to compute; a grid with a zero dimension is not a valid launch.
    if (m == 0 || n == 0)
    {
        return out;
    }

    // Value-initialize so every member the launcher does not set explicitly
    // (debug pointers, num_experts, topk, stride_token) is zero rather than
    // indeterminate stack contents.
    KernelArgs args{};
    size_t arg_size = sizeof(args);

    args.a_ptr     = (void *)XQ.data_ptr();
    args.b_ptr     = (void *)WQ.data_ptr();
    args.c_ptr     = nullptr;
    args.sa_ptr    = (void *)x_scale.data_ptr();
    args.sb_ptr    = (void *)w_scale.data_ptr();
    args.d_ptr     = nullptr;
    args.d_f16_ptr = (void *)out.data_ptr();

    args.num_tokens        = m;
    args.intermediate_size = n;
    args.hidden_size       = k;

    // Switch to XQ's device before querying the current stream so the launch
    // targets the device that owns the input.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(XQ));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // Function-local static: the kernel object is constructed once, on first
    // call, and reused by subsequent calls.
    static AiterAsmKernel impl_kernel("flatmm_uk_gfx9_f16f8_128x256x128_1x4x1_16x16x32", "flatmm_uk_gfx9_f16f8_128x256x128_1x4x1_16x16x32.co");

    // One workgroup per TileM x TileN output tile (ceil-div in both dimensions).
    const int gdx = (n + TileN - 1) / TileN;
    const int gdy = (m + TileM - 1) / TileM;

    impl_kernel.launch_kernel({&args,
                               &arg_size,
                               gdx,   // gdx
                               gdy,   // gdy
                               1,     // gdz
                               256,   // bdx: 4 wv64
                               1,     // bdy
                               1,     // bdz
                               stream});

    return out;
}