Commit e71afee2 authored by Jing Zhang
Browse files

add multiD support into batched_gemm_c_permute

parent 85978e02
...@@ -178,14 +178,17 @@ int main(int argc, char* argv[]) ...@@ -178,14 +178,17 @@ int main(int argc, char* argv[])
// do GEMM // do GEMM
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()), static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
{},
static_cast<EDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<EDataType*>(c_device_buf.GetDeviceBuffer()),
M, M,
N, N,
K, K,
stride_A, stride_A,
stride_B, stride_B,
{},
batch_stride_A, batch_stride_A,
batch_stride_B, batch_stride_B,
{},
batched_gemm_c_permute_desc, batched_gemm_c_permute_desc,
batch_count, batch_count,
a_element_op, a_element_op,
......
...@@ -16,26 +16,32 @@ struct BatchedGemmCPermuteDesc ...@@ -16,26 +16,32 @@ struct BatchedGemmCPermuteDesc
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename DELayout, typename DLayout,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename DsDataType,
typename EDataType, typename EDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation> typename CDEElementwiseOperation>
struct DeviceBatchedGemmCPermute : public BaseOperator struct DeviceBatchedGemmCPermute : public BaseOperator
{ {
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c, void* p_c,
index_t M, index_t M,
index_t N, index_t N,
index_t K, index_t K,
index_t stride_A, index_t stride_A,
index_t stride_B, index_t stride_B,
std::array<index_t, NumDTensor> stride_Ds,
index_t batch_stride_A, index_t batch_stride_A,
index_t batch_stride_B, index_t batch_stride_B,
std::array<index_t, NumDTensor> batch_stride_Ds,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc, BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
index_t BatchCount, index_t BatchCount,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
...@@ -45,26 +51,6 @@ struct DeviceBatchedGemmCPermute : public BaseOperator ...@@ -45,26 +51,6 @@ struct DeviceBatchedGemmCPermute : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
// Convenience alias: a std::unique_ptr that owns a DeviceBatchedGemmCPermute
// instantiated with the same nine template arguments, forwarded in the same
// order (layouts, data types, then the three elementwise operations).
// NOTE(review): this alias forwards exactly nine arguments and has no slot for
// a DsDataType parameter; confirm the arity still matches the
// DeviceBatchedGemmCPermute declaration it refers to.
template <typename ALayout,
typename BLayout,
typename DELayout,
typename ADataType,
typename BDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
using DeviceBatchedGemmCPermutePtr =
std::unique_ptr<DeviceBatchedGemmCPermute<ALayout,
BLayout,
DELayout,
ADataType,
BDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment