Commit a41f5481 authored by rocking

1. Fix coding style

2. Use DeviceGemm_Xdl_CShuffle instead of deprecated DeviceGemmXdl_C_Shuffle
parent 680cfaa7
@@ -15,7 +15,7 @@
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "device_gemm_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
@@ -50,19 +50,23 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle<
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
ADataType, // ADataType
BDataType, // BDataType
CDataType, // CDataType
AccDataType, // AccDataType
CDataType, // CShuffleDataType
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
PassThrough, // AElementwiseOperation
PassThrough, // BElementwiseOperation
PassThrough, // CElementwiseOperation
GemmDefault, // GemmSpec
1, // NumGemmKPrefetchStage
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
@@ -89,7 +93,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
true, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
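
For orientation, a minimal usage sketch of the renamed instance follows (not part of this commit). It assumes the device pointers p_a/p_b/p_c, the problem sizes M/N/K and the strides are set up as in the rest of this example, and that MakeArgument/MakeInvoker/Run follow the usual CK device-op pattern; exact signatures may differ between CK versions.

auto gemm    = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();

// Build the argument from (assumed) pre-existing device buffers and sizes.
auto argument = gemm.MakeArgument(p_a, p_b, p_c,             // device pointers for A, B, C
                                  M, N, K,                   // GEMM problem size
                                  StrideA, StrideB, StrideC, // leading-dimension strides
                                  PassThrough{},             // AElementwiseOperation
                                  PassThrough{},             // BElementwiseOperation
                                  PassThrough{});            // CElementwiseOperation

// Reject unsupported problem shapes before launching.
if(!gemm.IsSupportedArgument(argument))
{
    throw std::runtime_error("DeviceGemm_Xdl_CShuffle does not support this problem");
}

invoker.Run(argument, 1); // nrepeat = 1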
@@ -149,7 +153,7 @@ using DeviceReduceSumInstance =
1,
1>;
struct Sub_Exp
struct SubExp
{
__host__ __device__ constexpr void operator()(EltwiseComputeDataType& dst,
const EltwiseComputeDataType& src1,
@@ -174,7 +178,7 @@ using DeviceElementwiseSubExpInstance =
CDataType,
CDataType,
EltwiseComputeDataType,
Sub_Exp,
SubExp,
256,
8>;
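
The body of the renamed SubExp functor lies outside the lines shown above. Judging from how the example uses it (host_exp_m_n is computed from c_m_n and the per-row max c_n_max, then reduced with a sum), SubExp most likely computes exp(src1 - src2). A hypothetical body, for illustration only and assuming EltwiseComputeDataType is float:

struct SubExp
{
    __host__ __device__ constexpr void operator()(EltwiseComputeDataType& dst,
                                                  const EltwiseComputeDataType& src1,
                                                  const EltwiseComputeDataType& src2) const
    {
        // Hypothetical body: subtract the broadcast value (e.g. the row max) and exponentiate.
        dst = expf(src1 - src2);
    }
};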
@@ -412,7 +416,7 @@ int main(int argc, char* argv[])
{StrideC, 1},
{0, 1},
{StrideC, 1},
Sub_Exp{});
SubExp{});
if(!broadcastSubExp.IsSupportedArgument(broadcastSubExp_argument_ptr.get()))
{
@@ -515,8 +519,8 @@ int main(int argc, char* argv[])
Tensor<CDataType>,
Tensor<CDataType>,
EltwiseComputeDataType,
Sub_Exp,
0>(host_exp_m_n, c_m_n, c_n_max, M, N, Sub_Exp{});
SubExp,
0>(host_exp_m_n, c_m_n, c_n_max, M, N, SubExp{});
host_reduce_sum.Run(1, // alpha
reinterpret_cast<const HostReduceDataType*>(exp_m_n.mData.data()),
......
@@ -40,7 +40,7 @@ template <typename ADataType,
struct GridwiseBinaryElementwise_1D
{
static constexpr auto I0 = Number<0>{};
static constexpr auto thread_desc_M0 =
static constexpr auto thread_desc_m0 =
make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));
using PassThrough = tensor_operation::element_wise::PassThrough;
@@ -76,7 +76,7 @@ struct GridwiseBinaryElementwise_1D
ThreadwiseTensorSliceTransfer_v2<ADataType,
ComputeDataType,
GridDesc_M0,
decltype(thread_desc_M0),
decltype(thread_desc_m0),
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
@@ -88,7 +88,7 @@ struct GridwiseBinaryElementwise_1D
ThreadwiseTensorSliceTransfer_v2<BDataType,
ComputeDataType,
GridDesc_M0,
decltype(thread_desc_M0),
decltype(thread_desc_m0),
Sequence<ScalarPerVector>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
@@ -99,7 +99,7 @@ struct GridwiseBinaryElementwise_1D
auto c_global_write =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
CDataType,
decltype(thread_desc_M0),
decltype(thread_desc_m0),
GridDesc_M0,
PassThrough,
Sequence<ScalarPerVector>, // SliceLengths
@@ -122,19 +122,19 @@ struct GridwiseBinaryElementwise_1D
{
// read and process ScalarPerVector elements
a_global_load.Run(
a_grid_desc_m0, a_global_buf, thread_desc_M0, make_tuple(I0), a_thread_buf);
a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);
b_global_load.Run(
b_grid_desc_m0, b_global_buf, thread_desc_M0, make_tuple(I0), b_thread_buf);
b_grid_desc_m0, b_global_buf, thread_desc_m0, make_tuple(I0), b_thread_buf);
static_for<0, ScalarPerVector, 1>{}([&](auto m) {
constexpr auto offset = thread_desc_M0.CalculateOffset(make_tuple(m));
constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
functor(c_thread_buf(Number<offset>{}),
a_thread_buf(Number<offset>{}),
b_thread_buf(Number<offset>{}));
});
c_global_write.Run(thread_desc_M0,
c_global_write.Run(thread_desc_m0,
make_tuple(I0), // SrcSliceOriginIdx
c_thread_buf,
c_grid_desc_m0,
......