Commit 2564c493 authored by Chao Liu
Browse files

Merge remote-tracking branch 'origin/develop' into fused-gemm

parents 000eefbf 10b3278b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// FIXME: DeviceGemmReduce type need to well define the problem
template <typename ALayout,
typename BLayout,
typename DELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename RsDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename QsElementwiseOperation,
typename RsElementwiseOperation>
struct DeviceGemmMultipleDMultipleR : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr index_t NumRTensor = RsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
std::array<void*, NumRTensor> p_rs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
QsElementwiseOperation qs_element_op,
RsElementwiseOperation rs_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
// Owning-pointer alias: std::unique_ptr to the DeviceGemmMultipleDMultipleR instantiation that
// receives the same thirteen template arguments, forwarded in declaration order.
template <typename ALayout,
          typename BLayout,
          typename DELayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename EDataType,
          typename RsDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename QsElementwiseOperation,
          typename RsElementwiseOperation>
using DeviceGemmMultipleDMultipleRPtr = std::unique_ptr<
    DeviceGemmMultipleDMultipleR<ALayout, BLayout, DELayout,
                                 ADataType, BDataType, DsDataType, EDataType, RsDataType,
                                 AElementwiseOperation,
                                 BElementwiseOperation,
                                 CDEElementwiseOperation,
                                 QsElementwiseOperation,
                                 RsElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -130,6 +130,35 @@ struct AddHardswishAdd
}
};
// C = A * B
// E = C + D0 + D1
struct AddAdd
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const
{
// Only support floating so far
static_assert(is_same<E, half_t>::value || is_same<E, float>::value ||
is_same<E, double>::value,
"Data type is not supported by this operation!");
static_assert(is_same<C, half_t>::value || is_same<C, float>::value ||
is_same<C, double>::value,
"Data type is not supported by this operation!");
static_assert(is_same<D0, half_t>::value || is_same<D0, float>::value ||
is_same<D0, double>::value,
"Data type is not supported by this operation!");
static_assert(is_same<D1, half_t>::value || is_same<D1, float>::value ||
is_same<D1, double>::value,
"Data type is not supported by this operation!");
const C y = c + type_convert<C>(d0) + type_convert<C>(d1);
e = type_convert<E>(y);
}
};
// C = A * B
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
......
......@@ -1192,6 +1192,10 @@ struct ThreadwiseTensorSliceTransfer_v4
move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
}
// Replaces the stored source reference coordinate with one freshly built by
// make_tensor_coordinate from a default-constructed SrcDesc and the given index.
__device__ void SetSrcCoord(const Index& src_ref_idx)
{
    const auto src_desc = SrcDesc{};

    src_ref_coord_ = make_tensor_coordinate(src_desc, src_ref_idx);
}
private:
SrcCoord src_ref_coord_;
......
......@@ -18,5 +18,15 @@ __device__ void block_sync_lds()
__syncthreads();
#endif
}
// Device-side helper that emits one `s_nop 0` instruction at the call site through inline
// assembly. Takes no arguments and returns nothing.
__device__ void s_nop()
{
// `#if 1` hard-selects the inline-assembly path; the #else branch below is never compiled.
#if 1
// `volatile` marks the asm statement as having side effects; the empty `::` operand lists mean
// it has no outputs and no inputs.
asm volatile("\
s_nop 0 \n \
" ::);
#else
// Disabled alternative, kept for reference only.
// NOTE(review): presumably an experiment with a scheduling barrier in place of the nop --
// confirm intent before switching the `#if` above.
__builtin_amdgcn_sched_barrier(0);
#endif
}
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment