Commit 2fc2a189 authored by rocking
Browse files

Refine naming

parent 3bd10f4e
......@@ -57,7 +57,7 @@ using UnarySquareElementOp =
using DxsInElementOp = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOp = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using DGlobalMemOp =
using DxsGlobalMemOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
......@@ -70,7 +70,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
......@@ -274,10 +274,10 @@ int main()
reduceMeanSquare_device_buf.SetZero();
// Prepare LayerNorm
auto layerNorm = DeviceNormalizeInstance{};
auto layerNorm_invoker_ptr = layerNorm.MakeInvokerPointer();
auto layerNorm_argument =
layerNorm.MakeArgumentPointer(c_device_buf.GetDeviceBuffer(),
auto normalize = DeviceNormalizeInstance{};
auto normalize_invoker_ptr = normalize.MakeInvokerPointer();
auto normalize_argument =
normalize.MakeArgumentPointer(c_device_buf.GetDeviceBuffer(),
reduceMean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
......@@ -292,7 +292,7 @@ int main()
{StrideC, 1},
NormalizeFunctor{});
if(!layerNorm.IsSupportedArgument(layerNorm_argument.get()))
if(!normalize.IsSupportedArgument(normalize_argument.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"Device5AryElementwise_Xdl_CShuffle instance, exiting!");
......@@ -300,7 +300,7 @@ int main()
// run kernel
gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel});
layerNorm_invoker_ptr->Run(layerNorm_argument.get(), StreamConfig{nullptr, time_kernel});
normalize_invoker_ptr->Run(normalize_argument.get(), StreamConfig{nullptr, time_kernel});
bool pass = true;
{
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment