// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

#include "test/gemm/gemm_util.hpp"

using ADataType   = ck::bhalf_t;
using BDataType   = ck::bhalf_t;
using CDataType   = ck::bhalf_t;
using AccDataType = float;

#include "run_gemm_test.inc"

// Test entry point. All of the work is delegated to run_gemm_test(), which is
// presumably provided by run_gemm_test.inc (included above). Its result is
// returned from main and so becomes the process exit status.
int main()
{
    const auto status = run_gemm_test();
    return status;
}