Commit e71afee2 authored by Jing Zhang's avatar Jing Zhang
Browse files

add multiD support into batched_gemm_c_permute

parent 85978e02
......@@ -178,14 +178,17 @@ int main(int argc, char* argv[])
// do GEM
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
{},
static_cast<EDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
stride_A,
stride_B,
{},
batch_stride_A,
batch_stride_B,
{},
batched_gemm_c_permute_desc,
batch_count,
a_element_op,
......
......@@ -16,26 +16,32 @@ struct BatchedGemmCPermuteDesc
template <typename ALayout,
typename BLayout,
typename DELayout,
typename DLayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceBatchedGemmCPermute : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
std::array<index_t, NumDTensor> stride_Ds,
index_t batch_stride_A,
index_t batch_stride_B,
std::array<index_t, NumDTensor> batch_stride_Ds,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
index_t BatchCount,
AElementwiseOperation a_element_op,
......@@ -45,26 +51,6 @@ struct DeviceBatchedGemmCPermute : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename ALayout,
typename BLayout,
typename DELayout,
typename ADataType,
typename BDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
using DeviceBatchedGemmCPermutePtr =
std::unique_ptr<DeviceBatchedGemmCPermute<ALayout,
BLayout,
DELayout,
ADataType,
BDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment