Commit e9f656fa authored by rocking's avatar rocking
Browse files

[What] Rename F to H

[Why] F and G prepare for welford tensor
parent 421b718d
...@@ -39,7 +39,7 @@ using D0DataType = F16; ...@@ -39,7 +39,7 @@ using D0DataType = F16;
using D1DataType = F16; using D1DataType = F16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>; using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F16; using EDataType = F16;
using FDataType = F16; using HDataType = F16;
// Layout // Layout
using ALayout = Row; using ALayout = Row;
...@@ -48,22 +48,22 @@ using D0Layout = Row; ...@@ -48,22 +48,22 @@ using D0Layout = Row;
using D1Layout = Row; using D1Layout = Row;
using DsLayout = ck::Tuple<D0Layout, D1Layout>; using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using ELayout = Row; using ELayout = Row;
using FLayout = Row; using HLayout = Row;
using AElementOp = PassThrough; using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using CDEElementOp = AddReluAdd; using CDEElementOp = AddReluAdd;
using FElementOp = PassThrough; using HElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm_Xdl_CShuffle using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm_Xdl_CShuffle
//######| ALayout| BLayout| DsLayout| ELayout| FLayout| AData| BData| AccData| CShuffle| DsData| EData| FData| A| B| CDE| F| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| ReduceThreadTransfer| | //######| ALayout| BLayout| DsLayout| ELayout| HLayout| AData| BData| AccData| CShuffle| DsData| EData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| ReduceThreadTransfer| |
//######| | | | | | Type| Type| Type| DataType| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| //######| | | | | | Type| Type| Type| DataType| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| //######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock|
< ALayout, BLayout, DsLayout, ELayout, FLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, FDataType, AElementOp, BElementOp, CDEElementOp, FElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4>; < ALayout, BLayout, DsLayout, ELayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4>;
// clang-format on // clang-format on
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
...@@ -97,7 +97,7 @@ int main() ...@@ -97,7 +97,7 @@ int main()
ck::index_t StrideD0 = 0; ck::index_t StrideD0 = 0;
ck::index_t StrideD1 = 1024; ck::index_t StrideD1 = 1024;
ck::index_t StrideE = 1024; ck::index_t StrideE = 1024;
ck::index_t StrideF = 1024; ck::index_t StrideH = 1024;
// TODO - gamma and beta // TODO - gamma and beta
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
...@@ -105,7 +105,7 @@ int main() ...@@ -105,7 +105,7 @@ int main()
Tensor<D0DataType> d0_n(f_host_tensor_descriptor1d(N, 1)); Tensor<D0DataType> d0_n(f_host_tensor_descriptor1d(N, 1));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor2d(M, N, StrideD1, D1Layout{})); Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor2d(M, N, StrideD1, D1Layout{}));
Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{})); Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<FDataType> f_m_n(f_host_tensor_descriptor2d(M, N, StrideF, FLayout{})); Tensor<HDataType> h_m_n(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1}); a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1}); b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
...@@ -117,7 +117,7 @@ int main() ...@@ -117,7 +117,7 @@ int main()
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_n.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize());
DeviceMem f_device_buf(sizeof(FDataType) * f_m_n.mDesc.GetElementSpaceSize()); DeviceMem h_device_buf(sizeof(HDataType) * h_m_n.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
...@@ -127,7 +127,7 @@ int main() ...@@ -127,7 +127,7 @@ int main()
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{}; auto cde_element_op = CDEElementOp{};
auto f_element_op = FElementOp{}; auto h_element_op = HElementOp{};
auto device_op = DeviceOpInstance{}; auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker(); auto invoker = device_op.MakeInvoker();
...@@ -136,7 +136,7 @@ int main() ...@@ -136,7 +136,7 @@ int main()
b_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(),
{d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()}, {d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(), e_device_buf.GetDeviceBuffer(),
f_device_buf.GetDeviceBuffer(), h_device_buf.GetDeviceBuffer(),
M, M,
N, N,
K, K,
...@@ -144,11 +144,11 @@ int main() ...@@ -144,11 +144,11 @@ int main()
StrideB, StrideB,
{StrideD0, StrideD1}, {StrideD0, StrideD1},
StrideE, StrideE,
StrideF, StrideH,
a_element_op, a_element_op,
b_element_op, b_element_op,
cde_element_op, cde_element_op,
f_element_op); h_element_op);
if(!device_op.IsSupportedArgument(argument)) if(!device_op.IsSupportedArgument(argument))
{ {
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "device_base.hpp" #include "device_base.hpp"
...@@ -26,10 +25,10 @@ namespace device { ...@@ -26,10 +25,10 @@ namespace device {
// input : B[N, K] // input : B[N, K]
// input : D0[M, N], D1[M, N], ... // input : D0[M, N], D1[M, N], ...
// output : E[M, N] // output : E[M, N]
// output : F[M, N] // output : H[M, N]
// C = a_op(A) * b_op(B) // C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...) // E = cde_op(C, D0, D1, ...)
// F = layernorm(E) // H = layernorm(E)
// Assume: // Assume:
// D0, D1, ... and E have the same layout // D0, D1, ... and E have the same layout
// Calculate mean & variance along N dimension in layernorm(E) // Calculate mean & variance along N dimension in layernorm(E)
...@@ -37,7 +36,7 @@ template <typename ALayout, ...@@ -37,7 +36,7 @@ template <typename ALayout,
typename BLayout, typename BLayout,
typename DsLayout, typename DsLayout,
typename ELayout, typename ELayout,
typename FLayout, typename HLayout,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename AccDataType, typename AccDataType,
...@@ -48,7 +47,7 @@ template <typename ALayout, ...@@ -48,7 +47,7 @@ template <typename ALayout,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation,
typename FElementwiseOperation, typename HElementwiseOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
...@@ -163,7 +162,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -163,7 +162,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1)); using EGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
using FGridDesc_M_N = decltype(MakeGridDescriptor_M_N<FLayout>(1, 1, 1)); using HGridDesc_M_N = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -180,11 +179,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -180,11 +179,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
index_t StrideB, index_t StrideB,
std::array<index_t, NumDTensor> StrideDs, std::array<index_t, NumDTensor> StrideDs,
index_t StrideE, index_t StrideE,
index_t StrideF, index_t StrideH,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op, CDEElementwiseOperation cde_element_op,
BElementwiseOperation f_element_op) BElementwiseOperation h_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)}, : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)}, p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{}, p_ds_grid_{},
...@@ -194,11 +193,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -194,11 +193,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{}, ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideE)}, e_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideE)},
f_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<FLayout>(MRaw, NRaw, StrideF)}, h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op}, cde_element_op_{cde_element_op},
f_element_op_{f_element_op} h_element_op_{h_element_op}
{ {
// TODO // TODO
} }
...@@ -210,7 +209,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -210,7 +209,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
static_for<0, NumDTensor, 1>{}( static_for<0, NumDTensor, 1>{}(
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
std::cout << "F[M, N]: " << f_grid_desc_m_n_ << std::endl; std::cout << "H[M, N]: " << h_grid_desc_m_n_ << std::endl;
} }
// private: // private:
...@@ -227,7 +226,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -227,7 +226,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K b_grid_desc_n_k_; BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_; DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_;
FGridDesc_M_N f_grid_desc_m_n_; HGridDesc_M_N h_grid_desc_m_n_;
// TODO - tensor descriptors for block/thread-wise copy // TODO - tensor descriptors for block/thread-wise copy
// TODO - block-to-e-tile map // TODO - block-to-e-tile map
...@@ -236,7 +235,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -236,7 +235,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
FElementwiseOperation f_element_op_; HElementwiseOperation h_element_op_;
}; };
// Invoker // Invoker
...@@ -286,11 +285,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -286,11 +285,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
index_t StrideB, index_t StrideB,
std::array<index_t, NumDTensor> StrideDs, std::array<index_t, NumDTensor> StrideDs,
index_t StrideE, index_t StrideE,
index_t StrideF, index_t StrideH,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op, CDEElementwiseOperation cde_element_op,
FElementwiseOperation f_element_op) HElementwiseOperation h_element_op)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
...@@ -304,11 +303,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -304,11 +303,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
StrideB, StrideB,
StrideDs, StrideDs,
StrideE, StrideE,
StrideF, StrideH,
a_element_op, a_element_op,
b_element_op, b_element_op,
cde_element_op, cde_element_op,
f_element_op}; h_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -326,11 +325,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -326,11 +325,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
index_t StrideB, index_t StrideB,
std::array<index_t, NumDTensor> StrideDs, std::array<index_t, NumDTensor> StrideDs,
index_t StrideE, index_t StrideE,
index_t StrideF, index_t StrideH,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op, CDEElementwiseOperation cde_element_op,
FElementwiseOperation f_element_op) HElementwiseOperation h_element_op)
{ {
return std::make_unique<Argument>(p_a, return std::make_unique<Argument>(p_a,
p_b, p_b,
...@@ -344,11 +343,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -344,11 +343,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
StrideB, StrideB,
StrideDs, StrideDs,
StrideE, StrideE,
StrideF, StrideH,
a_element_op, a_element_op,
b_element_op, b_element_op,
cde_element_op, cde_element_op,
f_element_op); h_element_op);
} }
// polymorphic // polymorphic
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment