"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "c254e5abd2b01b9d5a2ba3fe4531e178623396d0"
Commit a41f5481 authored by rocking's avatar rocking
Browse files

1. Fix coding style

2. Use DeviceGemm_Xdl_CShuffle instead of deprecated DeviceGemmXdl_C_Shuffle
parent 680cfaa7
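The diff below captures the migration in one place: DeviceGemm_Xdl_CShuffle takes the layout parameters at the front of the template argument list and adds two parameters, GemmSpec and NumGemmKPrefetchStage, ahead of BlockSize. As a minimal before/after sketch of just the argument heads (the alias names OldGemm/NewGemm are illustrative; tile-size and tuning parameters are elided, the full list is in the diff):

    // Deprecated form: data types first, layouts after CShuffleDataType
    using OldGemm = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle<
        ADataType, BDataType, CDataType, AccDataType,
        CDataType,   // CShuffleDataType
        ALayout, BLayout, CLayout,
        PassThrough, PassThrough, PassThrough,
        /* BlockSize and tile/tuning parameters ... */>;

    // Replacement: layouts first, plus two new parameters
    using NewGemm = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
        ALayout, BLayout, CLayout,
        ADataType, BDataType, CDataType, AccDataType,
        CDataType,   // CShuffleDataType
        PassThrough, PassThrough, PassThrough,
        GemmDefault, // GemmSpec (new)
        1,           // NumGemmKPrefetchStage (new)
        /* BlockSize and tile/tuning parameters ... */>;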
@@ -15,7 +15,7 @@
 #include "device_tensor.hpp"
 #include "device_gemm_xdl.hpp"
-#include "device_gemm_xdl_c_shuffle.hpp"
+#include "device_gemm_xdl_cshuffle.hpp"
 #include "element_wise_operation.hpp"
 #include "reference_gemm.hpp"
 #include "gemm_specialization.hpp"
@@ -50,19 +50,23 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
 using BLayout = ck::tensor_layout::gemm::ColumnMajor;
 using CLayout = ck::tensor_layout::gemm::RowMajor;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
 // clang-format off
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle<
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
+    ALayout,      // ALayout
+    BLayout,      // BLayout
+    CLayout,      // CLayout
     ADataType,    // ADataType
     BDataType,    // BDataType
     CDataType,    // CDataType
     AccDataType,  // AccDataType
     CDataType,    // CShuffleDataType
-    ALayout,      // ALayout
-    BLayout,      // BLayout
-    CLayout,      // CLayout
     PassThrough,  // AElementwiseOperation
     PassThrough,  // BElementwiseOperation
     PassThrough,  // CElementwiseOperation
+    GemmDefault,  // GemmSpec
+    1,            // NumGemmKPrefetchStage
     256,          // BlockSize
     256,          // MPerBlock
     128,          // NPerBlock
@@ -89,7 +93,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle<
     true,                 // BBlockLdsAddExtraN
     1,                    // CShuffleMXdlPerWavePerShuffle
     1,                    // CShuffleNXdlPerWavePerShuffle
-    S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
+    S<1, 32, 1, 8>,       // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
     8>;                   // CBlockTransferScalarPerVector_NWaveNPerXdl
 // clang-format on
@@ -149,7 +153,7 @@ using DeviceReduceSumInstance =
     1,
     1>;
-struct Sub_Exp
+struct SubExp
 {
     __host__ __device__ constexpr void operator()(EltwiseComputeDataType& dst,
                                                   const EltwiseComputeDataType& src1,
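The diff view truncates the operator() body here. Given how SubExp is used in this example (subtract the row-wise max, then exponentiate, as in a numerically stable softmax), a minimal host-side sketch of the intended computation, with a hypothetical name and an assumed body, is:

    #include <cmath>

    // Hypothetical host-side analogue of SubExp; assumes the elided device
    // body computes the shifted exponential exp(src1 - src2).
    struct SubExpHost
    {
        void operator()(float& dst, const float& src1, const float& src2) const
        {
            dst = std::exp(src1 - src2); // subtract the max, then exponentiate
        }
    };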
@@ -174,7 +178,7 @@ using DeviceElementwiseSubExpInstance =
     CDataType,
     CDataType,
     EltwiseComputeDataType,
-    Sub_Exp,
+    SubExp,
     256,
     8>;
@@ -412,7 +416,7 @@ int main(int argc, char* argv[])
         {StrideC, 1},
         {0, 1},
         {StrideC, 1},
-        Sub_Exp{});
+        SubExp{});
     if(!broadcastSubExp.IsSupportedArgument(broadcastSubExp_argument_ptr.get()))
     {
@@ -515,8 +519,8 @@ int main(int argc, char* argv[])
         Tensor<CDataType>,
         Tensor<CDataType>,
         EltwiseComputeDataType,
-        Sub_Exp,
-        0>(host_exp_m_n, c_m_n, c_n_max, M, N, Sub_Exp{});
+        SubExp,
+        0>(host_exp_m_n, c_m_n, c_n_max, M, N, SubExp{});
     host_reduce_sum.Run(1, // alpha
                         reinterpret_cast<const HostReduceDataType*>(exp_m_n.mData.data()),
...
@@ -40,7 +40,7 @@ template <typename ADataType,
 struct GridwiseBinaryElementwise_1D
 {
     static constexpr auto I0 = Number<0>{};
-    static constexpr auto thread_desc_M0 =
+    static constexpr auto thread_desc_m0 =
         make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));
     using PassThrough = tensor_operation::element_wise::PassThrough;
@@ -76,7 +76,7 @@ struct GridwiseBinaryElementwise_1D
             ThreadwiseTensorSliceTransfer_v2<ADataType,
                                              ComputeDataType,
                                              GridDesc_M0,
-                                             decltype(thread_desc_M0),
+                                             decltype(thread_desc_m0),
                                              Sequence<ScalarPerVector>, // SliceLengths
                                              Sequence<0>,               // DimAccessOrder
                                              0,                         // SrcVectorDim
@@ -88,7 +88,7 @@ struct GridwiseBinaryElementwise_1D
             ThreadwiseTensorSliceTransfer_v2<BDataType,
                                              ComputeDataType,
                                              GridDesc_M0,
-                                             decltype(thread_desc_M0),
+                                             decltype(thread_desc_m0),
                                              Sequence<ScalarPerVector>, // SliceLengths
                                              Sequence<0>,               // DimAccessOrder
                                              0,                         // SrcVectorDim
@@ -99,7 +99,7 @@ struct GridwiseBinaryElementwise_1D
         auto c_global_write =
             ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                                CDataType,
-                                               decltype(thread_desc_M0),
+                                               decltype(thread_desc_m0),
                                                GridDesc_M0,
                                                PassThrough,
                                                Sequence<ScalarPerVector>, // SliceLengths
@@ -122,19 +122,19 @@ struct GridwiseBinaryElementwise_1D
         {
             // read and process ScalarPerVector elements
             a_global_load.Run(
-                a_grid_desc_m0, a_global_buf, thread_desc_M0, make_tuple(I0), a_thread_buf);
+                a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);
             b_global_load.Run(
-                b_grid_desc_m0, b_global_buf, thread_desc_M0, make_tuple(I0), b_thread_buf);
+                b_grid_desc_m0, b_global_buf, thread_desc_m0, make_tuple(I0), b_thread_buf);
             static_for<0, ScalarPerVector, 1>{}([&](auto m) {
-                constexpr auto offset = thread_desc_M0.CalculateOffset(make_tuple(m));
+                constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
                 functor(c_thread_buf(Number<offset>{}),
                         a_thread_buf(Number<offset>{}),
                         b_thread_buf(Number<offset>{}));
             });
-            c_global_write.Run(thread_desc_M0,
+            c_global_write.Run(thread_desc_m0,
                                make_tuple(I0), // SrcSliceOriginIdx
                                c_thread_buf,
                                c_grid_desc_m0,
...
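For readers outside the codebase, the loop above is the whole kernel pattern: each thread repeatedly loads ScalarPerVector elements of A and B into registers, applies the binary functor per element, and writes the results to C. A serial CPU sketch of the same pattern (illustrative names; the real kernel distributes chunks across threads and uses vectorized global transfers):

    #include <cstddef>
    #include <vector>

    // Serial sketch of GridwiseBinaryElementwise_1D's access pattern:
    // walk the flattened problem in chunks of scalar_per_vector,
    // applying a binary functor element by element.
    template <typename Functor>
    void binary_elementwise_1d(const std::vector<float>& a,
                               const std::vector<float>& b,
                               std::vector<float>& c, // sized like a and b
                               Functor functor,
                               std::size_t scalar_per_vector = 8)
    {
        for(std::size_t base = 0; base < a.size(); base += scalar_per_vector)
        {
            // read and process scalar_per_vector elements (cf. the kernel loop)
            for(std::size_t m = 0; m < scalar_per_vector && base + m < a.size(); ++m)
            {
                functor(c[base + m], a[base + m], b[base + m]);
            }
        }
    }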