Commit d8552699 authored by Chao Liu's avatar Chao Liu
Browse files

fix compilation

parent 40fabdcd
......@@ -7,10 +7,11 @@ using ADataType = BF16;
using BDataType = BF16;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = BF16;
using D1DataType = BF16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = BF16;
using CDataType = F32; // C matrix doesn't exsit in GPU memory, this is used for host verification
using D0DataType = BF16;
using D1DataType = BF16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = BF16;
using ALayout = Row;
using BLayout = Col;
......@@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
......
......@@ -6,8 +6,8 @@
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using CDataType = F16; // C matrix doesn't exsit in GPU memory, this is used for host verification
using CShuffleDataType = F32;
using CDataType = F32; // C matrix doesn't exsit in GPU memory, this is used for host verification
using D0DataType = F16;
using D1DataType = F16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
......@@ -7,10 +6,11 @@ using ADataType = F32;
using BDataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
using D1DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F32;
using CDataType = F32; // C matrix doesn't exsit in GPU memory, this is used for host verification
using D0DataType = F32;
using D1DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F32;
using ALayout = Row;
using BLayout = Col;
......@@ -36,7 +36,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
......
......@@ -11,10 +11,11 @@ using ADataType = I4;
using BDataType = I4;
using AccDataType = I32;
using CShuffleDataType = I32;
using D0DataType = I4;
using D1DataType = I4;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = I4;
using CDataType = I32; // C matrix doesn't exsit in GPU memory, this is used for host verification
using D0DataType = I4;
using D1DataType = I4;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = I4;
using KernelADataType = I8;
using KernelBDataType = I8;
......@@ -47,7 +48,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
......
......@@ -7,10 +7,11 @@ using ADataType = I8;
using BDataType = I8;
using AccDataType = I32;
using CShuffleDataType = I32;
using D0DataType = I8;
using D1DataType = I8;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = I8;
using CDataType = I32; // C matrix doesn't exsit in GPU memory, this is used for host verification
using D0DataType = I8;
using D1DataType = I8;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = I8;
using ALayout = Row;
using BLayout = Col;
......@@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
......
......@@ -235,6 +235,35 @@ struct AddAddFastGelu
e = type_convert<half_t>(x1_f);
}
template <>
__host__ __device__ constexpr void operator()<bhalf_t, float, bhalf_t, bhalf_t>(
bhalf_t& e, const float& c, const bhalf_t& d0, const bhalf_t& d1) const
{
const float x0_f = c + type_convert<float>(d0) + type_convert<float>(d1);
float x1_f = 0;
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
x0_f);
e = type_convert<bhalf_t>(x1_f);
}
template <>
__host__ __device__ constexpr void operator()<int8_t, int32_t, int8_t, int8_t>(
int8_t& e, const int32_t& c, const int8_t& d0, const int8_t& d1) const
{
const float x0_f =
type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1);
float x1_f = 0;
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
x0_f);
e = type_convert<int8_t>(x1_f);
}
};
struct Normalize
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment