Commit 32806d5f authored by Jun Liu

Merge branch 'amd-develop' into amd-master

parents e70a4d19 d0f355a3
-rocm-docs-core>=0.20.0
+rocm-docs-core==0.30.2
 sphinxcontrib-bibtex==2.6.1
@@ -16,7 +16,7 @@ beautifulsoup4==4.11.2
 # via pydata-sphinx-theme
 breathe==4.34.0
 # via rocm-docs-core
-certifi==2022.12.7
+certifi==2023.7.22
 # via requests
 cffi==1.15.1
 # via
@@ -26,7 +26,7 @@ charset-normalizer==3.1.0
 # via requests
 click==8.1.3
 # via sphinx-external-toc
-cryptography==40.0.2
+cryptography==41.0.6
 # via pyjwt
 deprecated==1.2.13
 # via pygithub
@@ -42,7 +42,7 @@ fastjsonschema==2.18.0
 # via rocm-docs-core
 gitdb==4.0.10
 # via gitpython
-gitpython==3.1.35
+gitpython==3.1.37
 # via rocm-docs-core
 idna==3.4
 # via requests
@@ -88,9 +88,9 @@ pydata-sphinx-theme==0.13.3
 # via
 # rocm-docs-core
 # sphinx-book-theme
-pygithub==1.58.2
+pygithub==1.58.1
 # via rocm-docs-core
-pygments==2.14.0
+pygments==2.15.0
 # via
 # accessible-pygments
 # pydata-sphinx-theme
@@ -109,11 +109,11 @@ pyyaml==6.0
 # pybtex
 # rocm-docs-core
 # sphinx-external-toc
-requests==2.28.2
+requests==2.31.0
 # via
 # pygithub
 # sphinx
-rocm-docs-core==0.27.0
+rocm-docs-core==0.30.2
 # via -r requirements.in
 six==1.16.0
 # via
@@ -141,7 +141,7 @@ sphinx-book-theme==1.0.1
 # via rocm-docs-core
 sphinx-copybutton==0.5.1
 # via rocm-docs-core
-sphinx-design==0.3.0
+sphinx-design==0.4.1
 # via rocm-docs-core
 sphinx-external-toc==0.3.1
 # via rocm-docs-core
@@ -163,7 +163,7 @@ sphinxcontrib-serializinghtml==1.1.5
 # via sphinx
 typing-extensions==4.5.0
 # via pydata-sphinx-theme
-urllib3==1.26.15
+urllib3==1.26.18
 # via requests
 wrapt==1.15.0
 # via deprecated
...
===============
Wrapper
===============
-------------------------------------
Description
-------------------------------------
.. note::
   The wrapper is under development, so its functionality is currently limited.

CK provides a lightweight wrapper for the more complex operations implemented in
the library. It allows nested layouts to be indexed through a simple interface
(avoiding complex descriptor transformations) and provides memory access through
the ``Tensor`` abstraction.

Example:

.. code-block:: cpp
    const auto shape_4x2x4    = ck::make_tuple(4, ck::make_tuple(2, 4));
    const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8));
    const auto layout         = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8);

    std::array<ck::index_t, 32> data;
    auto tensor =
        ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Generic>(&data[0], layout);
    for(ck::index_t w = 0; w < size(tensor); w++) {
        tensor(w) = w;
    }

    // slice() == slice(0, -1) (the whole dimension)
    auto tensor_slice = tensor(ck::wrapper::slice(1, 3),
                               ck::make_tuple(ck::wrapper::slice(), ck::wrapper::slice()));

    std::cout << "dims:2,(2,4) strides:2,(1,8)" << std::endl;
    for(ck::index_t h = 0; h < ck::wrapper::size<0>(tensor_slice); h++)
    {
        for(ck::index_t w = 0; w < ck::wrapper::size<1>(tensor_slice); w++)
        {
            std::cout << tensor_slice(h, w) << " ";
        }
        std::cout << std::endl;
    }
Output::

    dims:2,(2,4) strides:2,(1,8)
    1 5 9 13 17 21 25 29
    2 6 10 14 18 22 26 30
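The printed values follow from the strides: a logical coordinate ``(i, (j, k))`` of this
layout maps to memory offset ``2*i + 1*j + 8*k``, so ``slice(1, 3)`` keeps rows ``i = 1``
and ``i = 2`` of the first dimension while each printed row walks the nested ``(2, 4)``
dimensions of the original tensor.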
-------------------------------------
Layout
-------------------------------------

.. doxygenstruct:: ck::wrapper::Layout

-------------------------------------
Layout helpers
-------------------------------------

.. doxygenfile:: layout_utils.hpp

-------------------------------------
Tensor
-------------------------------------

.. doxygenstruct:: ck::wrapper::Tensor

-------------------------------------
Tensor helpers
-------------------------------------

.. doxygenfile:: tensor_utils.hpp
 #include <iostream>
 #include <cstdlib>
+#include <random>
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
@@ -48,10 +49,8 @@ void host_elementwise4D(HostTensorB& B_nhwc,
                 for(std::size_t n = 0; n < N; ++n)
                 {
                     ADataType tmp_val;
-                    // auto a_val = A_nchw(n, c, h, w);
                     auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)];
                     functor_b(tmp_val, a_val);
-                    // functor_a(B_nhwc(n, h, w, c), scale * tmp_val);
                     functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)],
                               scale * tmp_val);
                 }
@@ -62,12 +61,14 @@ int main()
     bool do_verification = true;
     bool time_kernel     = true;
-    std::vector<std::size_t> nchw = {4, 2, 1, 8};
-    std::vector<std::size_t> nhwc = {4, 1, 8, 2};
+    std::vector<std::size_t> nchw = {16, 8, 32, 64};
+    std::vector<std::size_t> nhwc = {16, 32, 64, 8};
     Tensor<ADataType> a(nchw);
     Tensor<BDataType> b(nhwc);
     float scale = 1.f;
     auto i      = 0;
+    std::mt19937 gen(11939);
+    std::uniform_int_distribution<int> dis(0, 1);
     for(std::size_t w = 0; w < a.mDesc.GetLengths()[3]; ++w)
         for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h)
             for(std::size_t c = 0; c < a.mDesc.GetLengths()[1]; ++c)
@@ -75,7 +76,7 @@ int main()
                 {
                     a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) +
                             (h * nchw[3]) + w] = i;
-                    i++;
+                    i = dis(gen);
                 }
     DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
...
@@ -67,6 +67,8 @@ int main()
     float scale = 1.f;
     auto i      = 0;
+    std::mt19937 gen(11939);
+    std::uniform_int_distribution<int> dis(0, 1);
     for(std::size_t w = 0; w < a.mDesc.GetLengths()[3]; ++w)
         for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h)
             for(std::size_t c = 0; c < a.mDesc.GetLengths()[1]; ++c)
@@ -74,7 +76,7 @@ int main()
                 {
                     a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) +
                             (h * nchw[3]) + w] = i;
-                    i++;
+                    i = dis(gen);
                 }
     DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
...
add_example_executable(example_layernorm2d_bwd_fp32 layernorm2d_bwd_fp32.cpp)
@@ -15,16 +15,17 @@
 #include "ck/library/utility/literals.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp"
-using DYDataType         = ck::half_t;
-using XDataType          = ck::half_t;
-using GammaDataType      = ck::half_t;
+using DYDataType         = float;
+using XDataType          = float;
+using GammaDataType      = float;
 using MeanInvStdDataType = float;
-using DGammaDataType     = ck::half_t;
-using DBetaDataType      = ck::half_t;
-using DXDataType         = ck::half_t;
+using DGammaDataType     = float;
+using DBetaDataType      = float;
+using DXDataType         = float;
 using ComputeDataType    = float;
 constexpr int Rank = 2;
@@ -39,6 +40,7 @@ constexpr int NumReduceDim = 1;
 // inv_std: [M, 1]
 // Output shape
+// dx: [M, N]
 // dgamma: [1, N]
 // dbeta: [1, N]
@@ -46,8 +48,34 @@ constexpr int NumReduceDim = 1;
 // dbeta = reduce_sum(dy, axis=0)
 // [CAUTION]
-// In DeviceNormalizationBwdGammaBetaImpl, M is the invariant dimension, K is the reduced
-// dimension. Hence, M in this example and in DeviceNormalizationBwdGammaBetaImpl is different.
+// In DeviceNormalizationBwdDataImpl & DeviceNormalizationBwdGammaBetaImpl, M is the invariant
+// dimension, K is the reduced dimension. Hence, M in this example and in
+// DeviceNormalizationBwdDataImpl & DeviceNormalizationBwdGammaBetaImpl is different.
+using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdDataImpl<
+    DYDataType,
+    XDataType,
+    GammaDataType,
+    MeanInvStdDataType,
+    ComputeDataType,
+    DXDataType,
+    Rank,
+    NumReduceDim,
+    256,   // BlockSize
+    8,     // MThreadClusterSize
+    32,    // KThreadClusterSize
+    1,     // MThreadSliceSize
+    4,     // KThreadSliceSize
+    true,  // IsDYFastestDimReduced
+    4,     // DYSrcVectorSize
+    true,  // IsXFastestDimReduced
+    4,     // XSrcVectorSize
+    true,  // IsGammaFastestDimReduced
+    4,     // GammaSrcVectorSize
+    false, // IsMeanInvStdFastestDimReduced
+    1,     // MeanInvStdSrcVectorSize
+    true,  // IsDXFastestDimReduced
+    4>;    // DXDstVectorSize
 using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaImpl<
     DYDataType,
     XDataType,
@@ -58,18 +86,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio
     Rank,
     NumReduceDim,
     256,   // BlockSize
-    8,     // ClusterInvarient
-    32,    // ClusterReduce
-    8,     // SliceInvarient
-    1,     // SliceReduce
+    8,     // MThreadClusterSize
+    32,    // KThreadClusterSize
+    4,     // MThreadSliceSize
+    1,     // KThreadSliceSize
     false, // IsDYFastestDimReduced
-    8,     // DYSrcVectorSize
+    4,     // DYSrcVectorSize
     false, // IsXFastestDimReduced
-    8,     // XSrcVectorSize
+    4,     // XSrcVectorSize
     true,  // IsMeanInvStdFastestDimReduced
     1,     // MeanInvStdSrcVectorSize
-    1,     // DGammaDstVectorSize
-    1>;    // DBetaDstVectorSize
+    4,     // DGammaDstVectorSize
+    4>;    // DBetaDstVectorSize
 int main()
 {
@@ -96,16 +124,48 @@ int main()
     DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize());
     DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
+    DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
     DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize());
     DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize());
+    DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize());
     DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize());
     DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize());
     dy_dev.ToDevice(dy.mData.data());
     x_dev.ToDevice(x.mData.data());
+    gamma_dev.ToDevice(gamma.mData.data());
     mean_dev.ToDevice(mean.mData.data());
     inv_std_dev.ToDevice(inv_std.mData.data());
+    // backward x
+    auto x_device_instance = XDeviceInstance{};
+    auto x_argument_ptr    = x_device_instance.MakeArgumentPointer({M, N}, // lengths
+                                                                   {N, 1}, // dyStrides
+                                                                   {N, 1}, // xStrides
+                                                                   {0, 1}, // gammaStrides
+                                                                   {1, 0}, // meanStrides
+                                                                   {1, 0}, // invStdStrides
+                                                                   {N, 1}, // dxStrides
+                                                                   {1},    // reduceDims
+                                                                   dy_dev.GetDeviceBuffer(),
+                                                                   x_dev.GetDeviceBuffer(),
+                                                                   gamma_dev.GetDeviceBuffer(),
+                                                                   mean_dev.GetDeviceBuffer(),
+                                                                   inv_std_dev.GetDeviceBuffer(),
+                                                                   dx_dev.GetDeviceBuffer());
+    if(!x_device_instance.IsSupportedArgument(x_argument_ptr.get()))
+    {
+        std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
+                  << std::endl;
+        return 1;
+    };
+    auto x_invoker_ptr = x_device_instance.MakeInvokerPointer();
+    x_invoker_ptr->Run(x_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+    // backward gamma & beta
     auto gamma_beta_device_instance = GammaBetaDeviceInstance{};
     auto gamma_beta_argument_ptr =
         gamma_beta_device_instance.MakeArgumentPointer({M, N}, // inLengths
@@ -126,7 +186,8 @@ int main()
     if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get()))
     {
-        std::cout << "The runtime parameters are not supported" << std::endl;
+        std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
+                  << std::endl;
         return 1;
     };
@@ -156,9 +217,11 @@ int main()
         dgamma_dev.FromDevice(dgamma.mData.data());
         dbeta_dev.FromDevice(dbeta.mData.data());
+        dx_dev.FromDevice(dx.mData.data());
         pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3);
         pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3);
+        pass &= ck::utils::check_err(dx, host_dx, "Error: Incorrect dx", 1e-3, 1e-3);
     }
     return (pass ? 0 : 1);
...
add_example_executable(example_layernorm2d_bwd_fp16 layernorm2d_bwd_fp16.cpp)
-add_example_executable(example_groupnorm_bwd_fp16 groupnorm_bwd_fp16.cpp)
+add_example_executable(example_groupnorm_bwd_fp32 groupnorm_bwd_fp32.cpp)
@@ -15,23 +15,58 @@
 #include "ck/library/utility/literals.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp"
-using DYDataType         = ck::half_t;
-using XDataType          = ck::half_t;
-using GammaDataType      = ck::half_t;
+using DYDataType         = float;
+using XDataType          = float;
+using GammaDataType      = float;
 using MeanInvStdDataType = float;
-using DGammaDataType     = ck::half_t;
-using DBetaDataType      = ck::half_t;
-using DXDataType         = ck::half_t;
+using DGammaDataType     = float;
+using DBetaDataType      = float;
+using DXDataType         = float;
 using ComputeDataType    = float;
 constexpr int Rank         = 5;
 constexpr int NumReduceDim = 3;
 // Groupnorm
-// kernel: M , K
+// kernel 1: M , K
+// dy:    N, H, W, G, C -> N * G, H * W * C
+// x:     N, H, W, G, C -> N * G, H * W * C
+// gamma: 1, 1, 1, G, C -> 1 * G, 1 * 1 * C
+// mean:  N, 1, 1, G, 1 -> N * G, 1 * 1 * 1
+// rstd:  N, 1, 1, G, 1 -> N * G, 1 * 1 * 1
+// dx:    N, H, W, G, C -> N * G, H * W * C
+using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdDataImpl<
+    DYDataType,
+    XDataType,
+    GammaDataType,
+    MeanInvStdDataType,
+    ComputeDataType,
+    DXDataType,
+    Rank,
+    NumReduceDim,
+    256,   // BlockSize
+    8,     // MThreadClusterSize
+    32,    // KThreadClusterSize
+    1,     // MThreadSliceSize
+    4,     // KThreadSliceSize
+    true,  // IsDYFastestDimReduced
+    4,     // DYSrcVectorSize
+    true,  // IsXFastestDimReduced
+    4,     // XSrcVectorSize
+    true,  // IsGammaFastestDimReduced
+    4,     // GammaSrcVectorSize
+    false, // IsMeanInvStdFastestDimReduced
+    1,     // MeanInvStdSrcVectorSize
+    true,  // IsDXFastestDimReduced
+    4>;    // DXDstVectorSize
+// kernel 2: M , K
 // dy:   N, H, W, G, C -> G * C, N * H * W
 // x:    N, H, W, G, C -> G * C, N * H * W
 // mean: N, 1, 1, G, 1 -> G * 1, N * 1 * 1
@@ -52,18 +87,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio
     Rank,
     NumReduceDim,
     256,   // BlockSize
-    8,     // ClusterInvarient
+    8,     // ClusterInvariant
     32,    // ClusterReduce
-    8,     // SliceInvarient
+    4,     // SliceInvariant
     1,     // SliceReduce
     false, // IsDYFastestDimReduced
-    8,     // DYSrcVectorSize
+    4,     // DYSrcVectorSize
     false, // IsXFastestDimReduced
-    8,     // XSrcVectorSize
+    4,     // XSrcVectorSize
     false, // IsMeanInvStdFastestDimReduced
     1,     // MeanInvStdSrcVectorSize
-    1,     // DGammaDstVectorSize
-    1>;    // DBetaDstVectorSize
+    4,     // DGammaDstVectorSize
+    4>;    // DBetaDstVectorSize
 int main()
 {
@@ -93,20 +128,55 @@ int main()
     DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize());
     DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
+    DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
     DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize());
     DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize());
+    DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize());
     DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize());
     DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize());
     dy_dev.ToDevice(dy.mData.data());
     x_dev.ToDevice(x.mData.data());
+    gamma_dev.ToDevice(gamma.mData.data());
     mean_dev.ToDevice(mean.mData.data());
     inv_std_dev.ToDevice(inv_std.mData.data());
     std::vector<ck::index_t> dyStrides{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()};
     std::vector<ck::index_t> xStrides{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()};
+    std::vector<ck::index_t> gammaStrides  = {0, 0, 0, C, 1};
     std::vector<ck::index_t> meanStrides   = {G, 0, 0, 1, 0};
     std::vector<ck::index_t> invStdStrides = {G, 0, 0, 1, 0};
+    std::vector<ck::index_t> dxStrides{dx.mDesc.GetStrides().begin(), dx.mDesc.GetStrides().end()};
+    // backward x
+    auto x_device_instance = XDeviceInstance{};
+    auto x_argument_ptr    = x_device_instance.MakeArgumentPointer({N, H, W, G, C}, // lengths
+                                                                   dyStrides,       // dyStrides
+                                                                   xStrides,        // xStrides
+                                                                   gammaStrides,    // gammaStrides
+                                                                   meanStrides,     // meanStrides
+                                                                   invStdStrides,   // invStdStrides
+                                                                   dxStrides,       // dxStrides
+                                                                   {1, 2, 4},       // reduceDims
+                                                                   dy_dev.GetDeviceBuffer(),
+                                                                   x_dev.GetDeviceBuffer(),
+                                                                   gamma_dev.GetDeviceBuffer(),
+                                                                   mean_dev.GetDeviceBuffer(),
+                                                                   inv_std_dev.GetDeviceBuffer(),
+                                                                   dx_dev.GetDeviceBuffer());
+    if(!x_device_instance.IsSupportedArgument(x_argument_ptr.get()))
+    {
+        std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
+                  << std::endl;
+        return 1;
+    };
+    auto x_invoker_ptr = x_device_instance.MakeInvokerPointer();
+    x_invoker_ptr->Run(x_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+    // backward gamma & beta
     auto gamma_beta_device_instance = GammaBetaDeviceInstance{};
     auto gamma_beta_argument_ptr =
@@ -128,7 +198,8 @@ int main()
     if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get()))
     {
-        std::cout << "The runtime parameters are not supported" << std::endl;
+        std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
+                  << std::endl;
         return 1;
     };
@@ -158,9 +229,11 @@ int main()
         dgamma_dev.FromDevice(dgamma.mData.data());
         dbeta_dev.FromDevice(dbeta.mData.data());
+        dx_dev.FromDevice(dx.mData.data());
         pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3);
         pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3);
+        pass &= ck::utils::check_err(dx, host_dx, "Error: Incorrect dx", 1e-3, 1e-3);
     }
     return (pass ? 0 : 1);
...
@@ -42,6 +42,8 @@ foreach(gpu IN LISTS GPU_TARGETS)
         # ScaleAdd ScaleAdd Relu
         add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp)
         add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16)
+        add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp)
+        add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16)
         set(target 1)
     endif()
 endforeach()
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <type_traits>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
constexpr ck::index_t NDimSpatial = 3;
using InDataType = ck::half_t;
using WeiDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using OutDataType = ck::half_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
using BiasLayout = ck::tensor_layout::convolution::G_K;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <typename OutElementOp>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<OutLayout, BiasLayout>,
OutLayout,
InDataType,
WeiDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<OutDataType, OutDataType>,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
8>;
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
namespace {
// Use custom implementation to pass two more tensors for post op
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp,
typename DeviceConvNDFwdInstance>
bool run_grouped_conv_fwd(bool do_verification,
int init_method,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
const HostTensorDescriptor& in_g_n_c_wis_desc,
const HostTensorDescriptor& wei_g_k_c_xs_desc,
const HostTensorDescriptor& out_g_n_k_wos_desc,
const InElementOp& in_element_op,
const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op)
{
constexpr ck::index_t NumDs = 2;
const ck::index_t G = out_g_n_k_wos_desc.GetLengths()[0];
const ck::index_t K = out_g_n_k_wos_desc.GetLengths()[2];
// Logical broadcast bias (we have to pass bias lengths in the same format as output - GNKDHW)
std::array<ck::index_t, NDimSpatial + 3> bias_g_k_lengths;
std::array<ck::index_t, NDimSpatial + 3> bias_g_k_strides;
// Fill lengths other than G and K with 1, and the corresponding strides with 0
bias_g_k_lengths.fill(1);
bias_g_k_strides.fill(0);
bias_g_k_lengths[0] = G;
bias_g_k_lengths[2] = K;
bias_g_k_strides[0] = K; // stride to G
bias_g_k_strides[2] = 1; // stride to K
const auto broadcasted_bias_desc = HostTensorDescriptor(bias_g_k_lengths, bias_g_k_strides);
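    // (Illustration, not part of the original file: with this example's G = 2 and K = 128,
    // the descriptor maps any (g, n, k, d, h, w) index to offset g * 128 + k, so every batch
    // entry and spatial position reads from the same G * K bias values.)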
// y = relu ( alpha1 * conv(x) + alpha2 * z + bias )
Tensor<InDataType> in(in_g_n_c_wis_desc);
Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
Tensor<OutDataType> out_host(out_g_n_k_wos_desc);
Tensor<OutDataType> out_device(out_g_n_k_wos_desc);
std::array<Tensor<OutDataType>, NumDs> d_tensors = {Tensor<OutDataType>(out_g_n_k_wos_desc),
Tensor<OutDataType>(broadcasted_bias_desc)};
std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl;
std::cout << "out: " << out_host.mDesc << std::endl;
std::cout << "z_tensor: " << d_tensors[0].mDesc << std::endl;
std::cout << "bias_tensor: " << d_tensors[1].mDesc << std::endl;
// Make sure that we allocated only G * K values for bias
assert(static_cast<ck::index_t>(d_tensors[1].mData.size()) == G * K);
switch(init_method)
{
case 0: break;
case 1:
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2});
d_tensors[0].GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
d_tensors[1].GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
break;
default:
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-1.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.05, 0.05});
d_tensors[0].GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.05, 0.05});
d_tensors[1].GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.05, 0.05});
}
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
DeviceMem z_buf(sizeof(OutDataType) * d_tensors[0].mDesc.GetElementSpaceSize());
DeviceMem bias_buf(sizeof(OutDataType) * d_tensors[1].mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
z_buf.ToDevice(d_tensors[0].mData.data());
bias_buf.ToDevice(d_tensors[1].mData.data());
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
const std::array<const void*, NumDs> ds = {z_buf.GetDeviceBuffer(), bias_buf.GetDeviceBuffer()};
auto conv = DeviceConvNDFwdInstance{};
auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(),
ds,
out_device_buf.GetDeviceBuffer(),
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
std::array<std::array<ck::index_t, NDimSpatial + 3>, NumDs>{
e_g_n_k_wos_lengths, bias_g_k_lengths},
std::array<std::array<ck::index_t, NDimSpatial + 3>, NumDs>{
e_g_n_k_wos_strides, bias_g_k_strides},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
if(!conv.IsSupportedArgument(argument))
{
throw std::runtime_error("The device op with the specified compilation parameters does "
"not support this convolution problem.");
}
float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = conv_param.GetFlops() + G * K +
conv_param.GetOutputByte<OutDataType>() / sizeof(OutDataType);
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>() +
G * K * sizeof(OutDataType) + conv_param.GetOutputByte<OutDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< conv.GetTypeString() << std::endl;
if(do_verification)
{
auto ref_conv =
ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
0, /*Num A Elementwise Tensors*/
0, /*Num B Elementwise Tensors*/
NumDs>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in,
wei,
out_host,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
in_element_op,
wei_element_op,
out_element_op,
{},
{},
d_tensors);
ref_invoker.Run(ref_argument);
out_device_buf.FromDevice(out_device.mData.data());
return ck::utils::check_err(out_device, out_host, "Error: incorrect results!");
}
return true;
}
} // namespace
#include "run_convnd_fwd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
@@ -24,7 +24,7 @@ bool run_convnd_fwd_example(int argc, char* argv[])
     // Following shapes are selected to avoid overflow. Expect inf in case of
     // size increase for some elementwise ops.
     ck::utils::conv::ConvParam conv_param{
-        3, 1, 16, 128, 8, {3, 3, 3}, {17, 17, 17}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}};
+        3, 2, 16, 128, 8, {3, 3, 3}, {17, 17, 17}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}};
     if(argc == 1)
     {
...
add_example_executable(example_tensor_transform tensor_transform.cpp)
add_example_executable(example_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp)
@@ -26,7 +26,7 @@ inline std::string get_device_name()
     }
     const std::string raw_name(props.gcnArchName);
-    // https://github.com/ROCmSoftwarePlatform/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40
+    // https://github.com/ROCm/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40
     static std::map<std::string, std::string> device_name_map = {
         {"Ellesmere", "gfx803"},
         {"Baffin", "gfx803"},
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DXDataType,
index_t Rank,
index_t NumReduceDim>
struct DeviceNormalizationBwdData : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> lengths,
const std::vector<index_t> dyStrides,
const std::vector<index_t> xStrides,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> meanStrides,
const std::vector<index_t> invStdStrides,
const std::vector<index_t> dxStrides,
const std::vector<index_t> reduceDims,
const void* p_dy,
const void* p_x,
const void* p_gamma,
const void* p_mean,
const void* p_invStd,
void* p_dx) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
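// A minimal usage sketch (mirroring the layernorm backward example in this commit;
// `instance` stands for any concrete implementation of this interface, such as
// DeviceNormalizationBwdDataImpl, and the p_* pointers are hypothetical device buffers):
//
//   auto arg = instance.MakeArgumentPointer({M, N}, // lengths
//                                           {N, 1}, // dyStrides
//                                           {N, 1}, // xStrides
//                                           {0, 1}, // gammaStrides
//                                           {1, 0}, // meanStrides
//                                           {1, 0}, // invStdStrides
//                                           {N, 1}, // dxStrides
//                                           {1},    // reduceDims
//                                           p_dy, p_x, p_gamma, p_mean, p_invStd, p_dx);
//   if(instance.IsSupportedArgument(arg.get()))
//       instance.MakeInvokerPointer()->Run(arg.get(), StreamConfig{nullptr, false});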
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DXDataType,
index_t Rank,
index_t NumReduceDim>
using DeviceNormalizationBwdDataPtr = std::unique_ptr<DeviceNormalizationBwdData<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
Rank,
NumReduceDim>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -14,6 +14,7 @@
 #include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -500,22 +501,29 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
             // for sanity check of vector memory access
             for(index_t i = 0; i < NumATensor; ++i)
             {
-                a_mz_stride_[i] = a_ms_ks_strides[i][NumDimM - 1];
-                a_kz_stride_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1];
+                as_mz_consecutive_[i] = a_ms_ks_strides[i][NumDimM - 1] == 1;
+                as_kz_consecutive_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1] == 1;
+                as_max_read_elems_[i] =
+                    CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths[i], a_ms_ks_strides[i]);
             }
             for(index_t i = 0; i < NumBTensor; ++i)
             {
-                b_nz_stride_[i] = b_ns_ks_strides[i][NumDimN - 1];
-                b_kz_stride_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1];
+                bs_nz_consecutive_[i] = b_ns_ks_strides[i][NumDimN - 1] == 1;
+                bs_kz_consecutive_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1] == 1;
+                bs_max_read_elems_[i] =
+                    CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths[i], b_ns_ks_strides[i]);
             }
             for(index_t i = 0; i < NumDTensor; ++i)
             {
-                ds_nz_stride_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1];
+                ds_nz_consecutive_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
+                ds_max_read_elems_[i] =
+                    CalculateMaxRead<NumDimM, NumDimN>(d_ms_ns_lengths[i], d_ms_ns_strides[i]);
             }
-            e_nz_stride_ = e_ms_ns_stride[NumDimM + NumDimN - 1];
+            e_nz_consecutive_  = e_ms_ns_stride[NumDimM + NumDimN - 1] == 1;
+            e_max_write_elems_ = CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_length, e_ms_ns_stride);
         }
         // pointers
@@ -545,16 +553,19 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
         BElementwiseOperation b_element_op_;
         CDEElementwiseOperation cde_element_op_;
-        // Strides for the last M/N/K dimensions of A/B/Ds/E
-        // for sanity check of vector load/store
-        std::array<index_t, NumATensor> a_mz_stride_;
-        std::array<index_t, NumATensor> a_kz_stride_;
-        std::array<index_t, NumBTensor> b_nz_stride_;
-        std::array<index_t, NumBTensor> b_kz_stride_;
-        std::array<index_t, NumDTensor> ds_nz_stride_;
-        index_t e_nz_stride_;
+        // Describe whether the last part of a given dimension of A/B/D/E is consecutive
+        // in the memory or not.
+        std::array<bool, NumATensor> as_mz_consecutive_;
+        std::array<bool, NumATensor> as_kz_consecutive_;
+        std::array<bool, NumBTensor> bs_nz_consecutive_;
+        std::array<bool, NumBTensor> bs_kz_consecutive_;
+        std::array<bool, NumDTensor> ds_nz_consecutive_;
+        bool e_nz_consecutive_;
+
+        std::array<index_t, NumATensor> as_max_read_elems_;
+        std::array<index_t, NumBTensor> bs_max_read_elems_;
+        std::array<index_t, NumDTensor> ds_max_read_elems_;
+        index_t e_max_write_elems_;
     };
 // Invoker
@@ -643,73 +654,65 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
         // check vector load/store
         {
-            bool all_valid = true;
+            bool valid_as_access = true;
             static_for<0, NumATensor, 1>{}([&](auto i) {
-                // vector memory access of A: could be on M or AK1 dimension
-                if constexpr(ABlockTransferSrcVectorDim == 1)
-                {
-                    if(!(arg.a_mz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I1) %
-                                                             ABlockTransferSrcScalarPerVector ==
-                                                         0))
-                    {
-                        all_valid = false;
-                    }
-                }
-                else
-                {
-                    if(!(arg.a_kz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I2) %
-                                                             ABlockTransferSrcScalarPerVector ==
-                                                         0))
-                    {
-                        all_valid = false;
-                    }
-                }
-            });
+                const bool valid_a_vector_size =
+                    arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0;
+                const bool valid_a_access_dim_m =
+                    ABlockTransferSrcVectorDim == 1 && arg.as_mz_consecutive_[i];
+                const bool valid_a_access_dim_k =
+                    ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i];
+                const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
+                if(!(valid_a_vector_size && valid_a_access_dim))
+                {
+                    valid_as_access = false;
+                }
+            });
+            if(!valid_as_access)
+            {
+                return false;
+            }
-            // vector memory access of B: could be on N or BK1 dimension
+            bool valid_bs_access = true;
             static_for<0, NumBTensor, 1>{}([&](auto i) {
-                if constexpr(BBlockTransferSrcVectorDim == 1)
-                {
-                    if(!(arg.b_nz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I1) %
-                                                             BBlockTransferSrcScalarPerVector ==
-                                                         0))
-                    {
-                        all_valid = false;
-                    }
-                }
-                else
-                {
-                    if(!(arg.b_kz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I2) %
-                                                             BBlockTransferSrcScalarPerVector ==
-                                                         0))
-                    {
-                        all_valid = false;
-                    }
-                }
-            });
+                const bool valid_b_vector_size =
+                    arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0;
+                const bool valid_b_access_dim_n =
+                    BBlockTransferSrcVectorDim == 1 && arg.bs_nz_consecutive_[i];
+                const bool valid_b_access_dim_k =
+                    BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i];
+                const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
+                if(!(valid_b_vector_size && valid_b_access_dim))
+                {
+                    valid_bs_access = false;
+                }
+            });
+            if(!valid_bs_access)
+            {
+                return false;
+            }
-            // check vector load of Ds
+            bool valid_ds_access = true;
             static_for<0, NumDTensor, 1>{}([&](auto i) {
-                if(!(arg.ds_nz_stride_[i] == 1 &&
-                     arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) %
-                             CDEBlockTransferScalarPerVector_NPerBlock ==
-                         0))
+                const bool valid_d_vector_size =
+                    arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
+                // Vector read of Ds is always on N dimension.
+                const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
+                if(!(valid_d_vector_size && valid_d_access_dim))
                 {
-                    all_valid = false;
+                    valid_ds_access = false;
                 }
             });
+            if(!valid_ds_access)
+            {
+                return false;
+            }
-            // vector memory access of E: always on NPerBlock dimension
-            if(!(arg.e_nz_stride_ == 1 &&
-                 arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
-                         CDEBlockTransferScalarPerVector_NPerBlock ==
-                     0))
-            {
-                all_valid = false;
-            }
-            if(!all_valid)
+            const bool valid_e_vector_size =
+                arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
+            // Vector write of E is always on N dimension.
+            const bool valid_e_access_dim = arg.e_nz_consecutive_;
+            if(!(valid_e_vector_size && valid_e_access_dim))
             {
                 return false;
            }
...
@@ -13,6 +13,7 @@
 #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -183,7 +184,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
             return generate_tuple([&](auto i) { return vec[i]; }, num);
         };
-        const auto a_ms_ns_lengths = to_tuple(a_ms_ks_lengths_vec, Number<NumDimM + NumDimK>{});
+        const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_vec, Number<NumDimM + NumDimK>{});
         const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number<NumDimM + NumDimK>{});
         // dimension Ids for M0, M1, ...
@@ -194,14 +195,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
             typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{};
         // lengths for M0, M1, ...
-        const auto mLengths = get_container_subset(a_ms_ns_lengths, mDimIds);
+        const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);
         // lengths for K0, K1, ...
-        const auto kLengths = get_container_subset(a_ms_ns_lengths, kDimIds);
+        const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
         // naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
         const auto a_grid_desc_ms_ks =
-            make_naive_tensor_descriptor(a_ms_ns_lengths, a_ms_ks_strides);
+            make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
         // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
         const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
@@ -383,7 +384,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                  const void* p_b_grid,
                  std::array<const void*, NumDTensor> p_ds_grid,
                  void* p_e_grid,
-                 const std::vector<index_t>& a_ms_ns_lengths,
+                 const std::vector<index_t>& a_ms_ks_lengths,
                  const std::vector<index_t>& a_ms_ks_strides,
                  const std::vector<index_t>& b_ns_ks_lengths,
                  const std::vector<index_t>& b_ns_ks_strides,
@@ -398,7 +399,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
              p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
              p_ds_grid_{},
              p_e_grid_{static_cast<EDataType*>(p_e_grid)},
-             a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ns_lengths, a_ms_ks_strides)},
+             a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ks_lengths, a_ms_ks_strides)},
              b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_ns_ks_lengths, b_ns_ks_strides)},
              ds_grid_desc_m_n_{},
              e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)},
@@ -411,13 +412,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
              block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
-             cde_element_op_{cde_element_op},
-             a_mz_stride_{},
-             a_kz_stride_{},
-             b_nz_stride_{},
-             b_kz_stride_{},
-             ds_nz_stride_{},
-             e_nz_stride_{}
+             cde_element_op_{cde_element_op}
         {
             // populate pointer, batch stride, desc for Ds
             static_for<0, NumDTensor, 1>{}([&](auto i) {
@@ -448,18 +443,26 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
             }
             // for sanity check of vector memory access
-            a_mz_stride_ = a_ms_ks_strides[NumDimM - 1];
-            a_kz_stride_ = a_ms_ks_strides[NumDimM + NumDimK - 1];
+            a_mz_consecutive_ = a_ms_ks_strides[NumDimM - 1] == 1;
+            a_kz_consecutive_ = a_ms_ks_strides[NumDimM + NumDimK - 1] == 1;
+            a_max_read_elems_ =
+                CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths, a_ms_ks_strides);
-            b_nz_stride_ = b_ns_ks_strides[NumDimN - 1];
-            b_kz_stride_ = b_ns_ks_strides[NumDimN + NumDimK - 1];
+            b_nz_consecutive_ = b_ns_ks_strides[NumDimN - 1] == 1;
+            b_kz_consecutive_ = b_ns_ks_strides[NumDimN + NumDimK - 1] == 1;
+            b_max_read_elems_ =
+                CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths, b_ns_ks_strides);
             for(index_t i = 0; i < NumDTensor; ++i)
             {
-                ds_nz_stride_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1];
+                ds_nz_consecutive_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
+                ds_max_read_elems_[i] =
+                    CalculateMaxRead<NumDimM, NumDimN>(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
             }
-            e_nz_stride_ = e_ms_ns_strides[NumDimM + NumDimN - 1];
+            e_nz_consecutive_ = e_ms_ns_strides[NumDimM + NumDimN - 1] == 1;
+            e_max_write_elems_ =
+                CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_lengths, e_ms_ns_strides);
         }
         void Print() const
@@ -499,15 +502,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
         BElementwiseOperation b_element_op_;
         CDEElementwiseOperation cde_element_op_;
-        // Strides for the last M/N/K dimensions of A/B/Ds/E
-        // for sanity check of vector load/store
-        index_t a_mz_stride_;
-        index_t a_kz_stride_;
-        index_t b_nz_stride_;
-        index_t b_kz_stride_;
-        std::array<index_t, NumDTensor> ds_nz_stride_;
-        index_t e_mz_stride_;
-        index_t e_nz_stride_;
+        // Describe whether the last part of a given dimension of A/B/D/E is consecutive
+        // in the memory or not.
+        bool a_mz_consecutive_;
+        bool a_kz_consecutive_;
+        bool b_nz_consecutive_;
+        bool b_kz_consecutive_;
+        std::array<bool, NumDTensor> ds_nz_consecutive_;
+        bool e_nz_consecutive_;
+
+        index_t a_max_read_elems_;
+        index_t b_max_read_elems_;
+        std::array<index_t, NumDTensor> ds_max_read_elems_;
+        index_t e_max_write_elems_;
     };
 // Invoker
@@ -616,65 +623,47 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                           (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
                       "wrong!");
-        // vector memory access of A: could be on M or AK1 dimension
-        if constexpr(ABlockTransferSrcVectorDim == 1)
-        {
-            if(!(arg.a_mz_stride_ == 1 &&
-                 arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
-            {
-                return false;
-            }
-        }
-        else
-        {
-            if(!(arg.a_kz_stride_ == 1 &&
-                 arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
-            {
-                return false;
-            }
-        }
+        const bool valid_a_vector_size =
+            arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
+        const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_;
+        const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_;
+        const bool valid_a_access_dim   = valid_a_access_dim_m || valid_a_access_dim_k;
+        if(!(valid_a_vector_size && valid_a_access_dim))
+        {
+            return false;
+        }
-        // vector memory access of B: could be on N or BK1 dimension
-        if constexpr(BBlockTransferSrcVectorDim == 1)
-        {
-            if(!(arg.b_nz_stride_ == 1 &&
-                 arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
-            {
-                return false;
-            }
-        }
-        else
-        {
-            if(!(arg.b_kz_stride_ == 1 &&
-                 arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
-            {
-                return false;
-            }
-        }
+        const bool valid_b_vector_size =
+            arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
+        const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_;
+        const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_;
+        const bool valid_b_access_dim   = valid_b_access_dim_n || valid_b_access_dim_k;
+        if(!(valid_b_vector_size && valid_b_access_dim))
+        {
+            return false;
+        }
-        // vector memory access of Ds: always on NPerBlock dimension
-        bool valid_d_access = true;
+        bool valid_ds_access = true;
         static_for<0, NumDTensor, 1>{}([&](auto i) {
-            if(!(arg.ds_nz_stride_[i] == 1 &&
-                 arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) %
-                         CDEBlockTransferScalarPerVector_NPerBlock ==
-                     0))
+            const bool valid_d_vector_size =
+                arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
+            // Vector read of Ds is always on N dimension.
+            const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
+            if(!(valid_d_vector_size && valid_d_access_dim))
             {
-                valid_d_access = false;
+                valid_ds_access = false;
            }
         });
-        if(valid_d_access == false)
+        if(!valid_ds_access)
        {
             return false;
         }
-        // vector memory access of E: always on NPerBlock dimension
-        if(!(arg.e_nz_stride_ == 1 &&
-             arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
-                     CDEBlockTransferScalarPerVector_NPerBlock ==
-                 0))
+        const bool valid_e_vector_size =
+            arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
+        // Vector write of E is always on N dimension.
+        const bool valid_e_access_dim = arg.e_nz_consecutive_;
+        if(!(valid_e_vector_size && valid_e_access_dim))
         {
             return false;
         }
@@ -692,7 +681,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                         const void* p_b,
                         std::array<const void*, NumDTensor> p_ds,
                         void* p_e,
-                        const std::vector<index_t>& a_ms_ns_lengths,
+                        const std::vector<index_t>& a_ms_ks_lengths,
                         const std::vector<index_t>& a_ms_ks_strides,
                         const std::vector<index_t>& b_ns_ks_lengths,
                         const std::vector<index_t>& b_ns_ks_strides,
@@ -708,7 +697,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                         p_b,
                         p_ds,
                         p_e,
-                        a_ms_ns_lengths,
+                        a_ms_ks_lengths,
                         a_ms_ks_strides,
                         b_ns_ks_lengths,
                         b_ns_ks_strides,
@@ -729,7 +718,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                         const void* p_b,
                         std::array<const void*, NumDTensor> p_ds,
                         void* p_e,
-                        const std::vector<index_t>& a_ms_ns_lengths,
+                        const std::vector<index_t>& a_ms_ks_lengths,
                         const std::vector<index_t>& a_ms_ks_strides,
                         const std::vector<index_t>& b_ns_ks_lengths,
                         const std::vector<index_t>& b_ns_ks_strides,
@@ -745,7 +734,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                         p_b,
                         p_ds,
                         p_e,
-                        a_ms_ns_lengths,
+                        a_ms_ks_lengths,
                         a_ms_ks_strides,
                         b_ns_ks_lengths,
                         b_ns_ks_strides,
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cassert>
#include <sstream>
#include <vector>
#include "ck/ck.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
/**
* Calculates the maximum number of subsequent elements of the fast changing dimension
* that are consecutive in memory.
*
* Example:
* NumDimM = 2, NumDimK = 3
* A shape = [ 2, 3, 4, 5, 6]
* A strides = [360, 120, 30, 6, 1]
* | M | | K |
* It follows from strides that K is FCD and all the subsequent elements of K are consecutive
* in memory.
* But if strides were [360, 120, 6, 24, 1], then only 6 subsequent elements of K would be
* consecutive in memory.
*
* Assumes that the dimensions are split into two groups of `NumDim1` and `NumDim2` dimensions.
*/
template <index_t NumDim1, index_t NumDim2>
auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<index_t>& strides)
{
if(lengths.size() != NumDim1 + NumDim2)
{
std::ostringstream err;
err << "Incorrect number of lengths in " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__;
throw std::runtime_error(err.str());
}
if(strides.size() != NumDim1 + NumDim2)
{
std::ostringstream err;
err << "Incorrect number of strides in " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__;
throw std::runtime_error(err.str());
}
// Determine the beginning and end idx of the group representing the FCD.
index_t begin_idx, end_idx;
if(strides[NumDim1 - 1] == 1)
{
begin_idx = 0;
end_idx = NumDim1 - 1;
}
else if(strides[NumDim1 + NumDim2 - 1] == 1)
{
begin_idx = NumDim1;
end_idx = NumDim1 + NumDim2 - 1;
}
else
{
// The dimension consecutive in memory is not the last dimension of any group, so only
// one element can be read/written at once.
return 1;
}
index_t consecutive_stride = 1;
for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
{
if(strides[dim_idx] == consecutive_stride)
{
consecutive_stride *= lengths[dim_idx];
}
else
{
break;
}
}
const index_t max_subsequent_elems = consecutive_stride;
return max_subsequent_elems;
}
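
// Usage sketch (numbers taken from the example in the doc comment above):
//
//   // M dims: [2, 3], K dims: [4, 5, 6]; K is the fast changing dimension.
//   const std::vector<ck::index_t> lengths{2, 3, 4, 5, 6};
//   const std::vector<ck::index_t> strides{360, 120, 30, 6, 1};
//   // The whole of K is consecutive: 4 * 5 * 6 = 120 elements can be accessed at once.
//   assert((CalculateMaxRead<2, 3>(lengths, strides) == 120));
//
//   // With strides {360, 120, 6, 24, 1} only the innermost K dimension (length 6) is
//   // consecutive in memory, so the result drops to 6.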
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -357,15 +357,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
         return out_gemmm_gemmn_desc;
     }
+    // Shape of Ds and E must be aligned. Strides can be different.
+    // Pass e_g_n_k_wos_lengths for logical broadcast.
     static auto MakeDsGridDescriptor_M_N(
-        const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
+        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
         const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
     {
         return generate_tuple(
             [&](auto i) {
                 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
-                return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(ds_g_n_k_wos_lengths[i],
+                return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(e_g_n_k_wos_lengths,
                                                                   ds_g_n_k_wos_strides[i]);
             },
             Number<NumDTensor>{});
@@ -569,7 +571,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                 // D desc
                 ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
-                    ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]);
+                    e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]);
             });
             compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
@@ -916,8 +918,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                          is_same_v<DLayout, ctc::G_NDHW_K> || is_same_v<DLayout, ctc::GNWK> ||
                          is_same_v<DLayout, ctc::GNHWK> || is_same_v<DLayout, ctc::GNDHWK> ||
                          is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> ||
-                         is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::GK> ||
-                         is_same_v<DLayout, ctc::G_K>)
+                         is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::G_K>)
            {
                 const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
@@ -925,6 +926,27 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                 {
                     valid = false;
                 }
+                if constexpr(is_same_v<DLayout, ctc::G_K>)
+                {
+                    // G and K must be the same
+                    if(arg.ds_g_n_k_wos_lengths_[i][0] != arg.e_g_n_k_wos_lengths_[0] ||
+                       arg.ds_g_n_k_wos_lengths_[i][2] != arg.e_g_n_k_wos_lengths_[2])
+                    {
+                        valid = false;
+                    }
+                }
+                else
+                {
+                    // E and D must have the same shape
+                    for(index_t d = 0; d < NDimSpatial + 3; d++)
+                    {
+                        if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d])
+                        {
+                            valid = false;
+                        }
+                    }
+                }
             }
             else
             {
...