Commit 1675a341 authored by Jing Zhang's avatar Jing Zhang
Browse files

format

parent 6f8858bf
...@@ -72,7 +72,7 @@ using CDEElementOp = MultiplyMultiply; ...@@ -72,7 +72,7 @@ using CDEElementOp = MultiplyMultiply;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
// clang-format off // clang-format off
< Row, Col, DsLayout, ELayout, < Row, Col, DsLayout, ELayout,
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, AElementOp, BElementOp, CDEElementOp, GemmSpec,
...@@ -134,7 +134,8 @@ int main(int argc, char* argv[]) ...@@ -134,7 +134,8 @@ int main(int argc, char* argv[])
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n"); printf(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n");
exit(0); exit(0);
} }
......
...@@ -60,7 +60,7 @@ static constexpr ck::index_t Scale_Block_N = 128; ...@@ -60,7 +60,7 @@ static constexpr ck::index_t Scale_Block_N = 128;
static constexpr ck::index_t Scale_Block_K = 128; static constexpr ck::index_t Scale_Block_K = 128;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
// clang-format off // clang-format off
<Row, Col, DsLayout, ELayout, <Row, Col, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, AElementOp, BElementOp, CDEElementOp, GemmSpec,
......
...@@ -96,8 +96,6 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator ...@@ -96,8 +96,6 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -70,16 +70,16 @@ template <typename ALayout, ...@@ -70,16 +70,16 @@ template <typename ALayout,
typename LDSTypeA = ComputeTypeA, typename LDSTypeA = ComputeTypeA,
typename LDSTypeB = ComputeTypeB> typename LDSTypeB = ComputeTypeB>
struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayout, struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayout,
BLayout, BLayout,
DsLayout, DsLayout,
CLayout, CLayout,
ADataType, ADataType,
BDataType, BDataType,
DsDataType, DsDataType,
CDataType, CDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation> CElementwiseOperation>
{ {
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
...@@ -192,12 +192,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo ...@@ -192,12 +192,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo
// rotating mem // rotating mem
rotating_mem.Next(); rotating_mem.Next();
// clear c mem // clear c mem
if(arg_.KBatch > 1) if(arg_.KBatch > 1)
hipGetErrorString( hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
hipMemsetAsync(arg_.p_c_grid, 0,
0, arg_.M * arg_.N * sizeof(CDataType),
arg_.M * arg_.N * sizeof(CDataType), stream_config.stream_id_));
stream_config.stream_id_));
}; };
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>( ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment