"torchvision/csrc/vscode:/vscode.git/clone" did not exist on "6e639d3e49371a509235201cb3f335b1c5cac0e3"
Commit cbcc844e authored by illsilin

merge from public repo

parents 29deceb6 1f306024
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
-#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
-#error Should compile this file with ck::int4_t support
-#endif
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
 #include "common.hpp"
@@ -44,3 +42,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
 #include "run_gemm_example.inc"
 int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
+#endif
\ No newline at end of file
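
Across all of the int4 example files touched below, the same guard inversion is applied: rather than failing the build with #error when CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 is undefined, each file now wraps its whole body, main() included, in #ifdef, so it compiles to an empty translation unit when ck::int4_t is unavailable. A minimal sketch of the two patterns side by side (the example body and main below are stand-ins, not quoted from any file):

// Old pattern: any build without int4 support hits a hard error.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif

// New pattern: the example is simply compiled out when the extension is
// missing, so the rest of the build can proceed.
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
// ... example body ...
int main() { return 0; } // stand-in for the example's real main
#endif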
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
-#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
-#error Should compile this file with ck::int4_t support
-#endif
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
 #include "common.hpp"
@@ -58,3 +56,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
 #include "run_gemm_add_add_fastgelu_example.inc"
 int main(int argc, char* argv[]) { return !run_gemm_add_add_fastgelu_example(argc, argv); }
+#endif
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
-#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
-#error Should compile this file with ck::int4_t support
-#endif
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
 #define BUILD_INT4_EXAMPLE
@@ -24,3 +22,4 @@ using RsDataType = ck::Tuple<R0DataType>;
 #include "run_convnd_fwd_max_example.inc"
 int main(int argc, char* argv[]) { return !run_convnd_fwd_max_example(argc, argv); }
+#endif
@@ -299,8 +299,8 @@ int main(int argc, char* argv[])
     for(int i = 0; i < problem_size.group_count; i++)
     {
         problem_size.Ms.push_back(256 + 256 * i);
-        problem_size.Ns.push_back(128 + 128 * i);
-        problem_size.Ks.push_back(128 + 64 * i);
+        problem_size.Ns.push_back(256);
+        problem_size.Ks.push_back(128);
         problem_size.stride_As.push_back(problem_size.Ks[i]);
         problem_size.stride_Bs.push_back(problem_size.Ks[i]);
...
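
The default grouped-GEMM problem sizes now vary only M across groups, holding N = 256 and K = 128 fixed; since stride_As and stride_Bs are taken from Ks[i], the strides become uniform across groups too. For concreteness, a small sketch of the shapes the new loop produces (group_count = 4 is an illustrative value, not taken from the diff):

#include <cstdio>

int main()
{
    const int group_count = 4; // illustrative value
    for(int i = 0; i < group_count; i++)
    {
        // After this change only M grows with the group index.
        const int M = 256 + 256 * i, N = 256, K = 128;
        std::printf("group %d: M=%4d N=%d K=%d strideA=strideB=%d\n", i, M, N, K, K);
    }
}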
@@ -300,8 +300,8 @@ int main(int argc, char* argv[])
     for(int i = 0; i < problem_size.group_count; i++)
     {
         problem_size.Ms.push_back(256 + 256 * i);
-        problem_size.Ns.push_back(128 + 128 * i);
-        problem_size.Ks.push_back(128 + 64 * i);
+        problem_size.Ns.push_back(256);
+        problem_size.Ks.push_back(128);
         problem_size.stride_As.push_back(problem_size.Ks[i]);
         problem_size.stride_Bs.push_back(problem_size.Ks[i]);
...
@@ -272,15 +272,14 @@ int main(int argc, char* argv[])
     {
         for(int m = 0; m < M; ++m)
         {
             auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
             auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
+            ReduceAccDataType d0_val = 0;
+            ReduceAccDataType d1_val = 0;
             for(int n = 0; n < N; ++n)
             {
                 auto c_val =
                     ck::type_convert<ReduceAccDataType>(c_g_m_n_host_result(batch, m, n));
-                ReduceAccDataType d0_val;
-                ReduceAccDataType d1_val;
                 UnaryIdenticElementOp{}(d0_val, c_val);
                 UnarySquareElementOp{}(d1_val, c_val);
...
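
The fix hoists d0_val and d1_val out of the inner loop and zero-initializes them, removing two variables that were previously declared uninitialized on every iteration. For orientation, a self-contained sketch of the host-side verification pattern this hunk sits in, assuming UnaryIdenticElementOp copies its input and UnarySquareElementOp squares it (matching CK's element-wise ops of those names), with float standing in for the CK data types:

#include <cstdio>

int main()
{
    const int M = 2, N = 4;
    const float c[2][4] = {{1, 2, 3, 4}, {5, 6, 7, 8}};

    for(int m = 0; m < M; ++m)
    {
        float reduce0_acc = 0.f; // identity value of the Add reduction
        float reduce1_acc = 0.f;
        float d0_val = 0.f; // hoisted and zero-initialized, as in the fix
        float d1_val = 0.f;
        for(int n = 0; n < N; ++n)
        {
            const float c_val = c[m][n];
            d0_val = c_val;         // UnaryIdenticElementOp
            d1_val = c_val * c_val; // UnarySquareElementOp
            reduce0_acc += d0_val;  // running sum
            reduce1_acc += d1_val;  // running sum of squares
        }
        std::printf("row %d: sum=%g sumsq=%g\n", m, reduce0_acc, reduce1_acc);
    }
}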
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
-#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
-#error Should compile this file with ck::int4_t support
-#endif
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
 #include "common.hpp"
@@ -29,3 +27,4 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
 #include "run_grouped_conv_fwd_bias_relu_add_example.inc"
 int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); }
+#endif
@@ -9,9 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
     Gemm1
 */
-#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
-#error Should compile this file with ck::int4_t support
-#endif
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
 #include <iostream>
 #include <numeric>
@@ -144,3 +142,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
 #endif
 int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; }
+#endif
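
The static_assert visible in the second hunk header is worth noting: ck::int4_t occupies a full byte on the host, which is what lets these examples reuse int8 host buffers for int4 kernels. A hedged illustration in plain C++ (Int4Stub is a hypothetical stand-in, not CK's actual definition):

#include <cstdint>

// Hypothetical stand-in: a 4-bit quantity stored in one byte, as the
// static_assert in the diff requires of ck::int4_t.
struct Int4Stub
{
    int8_t value; // only the low 4 bits are meaningful
};
static_assert(sizeof(Int4Stub) == sizeof(int8_t), "int4 is byte-sized on the host");

int main() { return 0; }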
@@ -10,6 +10,9 @@ foreach(gpu IN LISTS GPU_TARGETS)
 add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
 add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16)
+add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp)
+add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16)
 add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp)
 add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16)
...
@@ -157,7 +157,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
     if(config.time_kernel)
     {
-        float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
+        float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 1});
         std::size_t flop = std::size_t(2) * M * N * K;
         std::size_t num_btype =
...
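
The added literal fills StreamConfig's third field. In CK versions around this commit that field is the log level (my reading of ck/stream_config.hpp; treat the field names and defaults below as an assumption, not a spec), so the change turns on timing log output rather than altering the measurement itself:

#include <hip/hip_runtime.h> // for hipStream_t

// Hedged sketch of the aggregate being initialized.
struct StreamConfigSketch
{
    hipStream_t stream_id_ = nullptr;
    bool time_kernel_      = false;
    int log_level_         = 0;  // the new literal `1` lands here
    int cold_niters_       = 5;  // warm-up runs, default unchanged
    int nrepeat_           = 50; // timed runs, default unchanged
};

int main()
{
    // Equivalent of StreamConfig{nullptr, config.time_kernel, 1}:
    StreamConfigSketch cfg{nullptr, true, 1};
    return cfg.log_level_ == 1 ? 0 : 1;
}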
@@ -42,7 +42,7 @@ using AElementOp = PassThrough;
 using BElementOp = PassThrough;
 using CElementOp = PassThrough;
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::KPadding;
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle
 // clang-format off
...
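
Moving from Default to KPadding allows the kernel to pad K up to a whole number of K tiles, which split-K needs whenever K is not already divisible by the product of the split factor and the per-block K tile. A worked sketch of that padding arithmetic (KPerBlock = 32 and KBatch = 4 are illustrative values, not read from this file):

#include <cstdio>

int main()
{
    const int K = 1000, KPerBlock = 32, KBatch = 4; // illustrative values
    // Round K up so each of the KBatch slices covers whole KPerBlock tiles.
    const int KPerSlice = (K + KBatch * KPerBlock - 1) / (KBatch * KPerBlock) * KPerBlock;
    const int KPadded   = KPerSlice * KBatch;
    std::printf("K=%d -> padded K=%d (%d per split-K slice)\n", K, KPadded, KPerSlice);
}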
New file splitK_gemm_xdl_lds_direct_load_fp16.cpp (registered in the CMake hunk above), added in full:
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#define DIRECT_LOAD 1
#if DIRECT_LOAD
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp"
#else
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
#endif
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CDataType = F16;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
#if DIRECT_LOAD
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle_LdsDirectLoad
// clang-format off
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 16, 16, 16, 1, 1, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, 1, 1, S<1, 32, 1, 4>, 4>;
// clang-format on
#else
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle
// clang-format off
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
#endif
#include "run_splitK_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); }
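
The driver, run_splitK_gemm_example.inc, is not shown in this diff. As orientation only, a hypothetical outline of the usual CK example flow for an instance like this one; the MakeArgument parameter list is abbreviated and assumed, not quoted:

// Hypothetical outline, not the contents of run_splitK_gemm_example.inc.
bool run_splitK_gemm_sketch()
{
    auto gemm    = DeviceGemmInstance{};
    auto invoker = gemm.MakeInvoker();

    // Pointers, problem sizes, strides, element ops and the split-K factor;
    // the real argument list is longer and file-specific.
    auto argument = gemm.MakeArgument(/* p_a, p_b, p_c, M, N, K, strides,
                                         element ops, k_batch */);

    if(!gemm.IsSupportedArgument(argument))
        return false; // this tile configuration cannot handle the problem

    // The third StreamConfig field is the log level, as in the hunk above.
    invoker.Run(argument, StreamConfig{nullptr, /* time_kernel = */ true, 1});
    return true;
}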
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
-#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
-#error Should compile this file with ck::int4_t support
-#endif
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
 #include <cstdlib>
 #include <iostream>
@@ -120,3 +118,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
 #endif
 int main(int argc, char* argv[]) { return run_grouped_conv_conv_fwd_example(argc, argv) ? 0 : 1; }
+#endif
@@ -5,4 +5,6 @@ add_example_executable(example_elementwise_permute_4D_fp16_row elementwise_permu
 add_example_executable(example_elementwise_permute_4D_fp32_col elementwise_permute_4D_fp32_col.cpp)
 add_example_executable(example_elementwise_permute_4D_fp16_col elementwise_permute_4D_fp16_col.cpp)
 add_example_executable(example_elementwise_permute elementwise_permute.cpp)
-add_example_executable(example_elementwise_permute_3d elementwise_permute_3d.cpp)
+if((NOT GPU_TARGETS MATCHES "gfx940") AND (NOT GPU_TARGETS MATCHES "gfx941") AND (NOT GPU_TARGETS MATCHES "gfx942"))
+    add_example_executable(example_elementwise_permute_3d elementwise_permute_3d.cpp)
+endif()
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 #include <iostream>
 #include <cstdlib>
...
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 #include <iostream>
 #include <cstdlib>
@@ -14,8 +17,8 @@
 using F16 = ck::half_t;
 using F32 = float;
-using ADataType = F16;
-using BDataType = F16;
+using ADataType = F32;
+using BDataType = F32;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using DeviceElementwisePermuteInstance =
@@ -25,10 +28,10 @@ using DeviceElementwisePermuteInstance =
     2,               // NumDim_m, {N, C}
     2,               // NumDim_n, {H, W}
     1,               // NumDim_k, {D}
-    8,               // MPerThread
-    8,               // NPerThread
-    8,               // KPerThread
-    ck::Sequence<8>, // InScalarPerVectorSeq
+    4,               // MPerThread
+    4,               // NPerThread
+    4,               // KPerThread
+    ck::Sequence<4>, // InScalarPerVectorSeq
     ck::Sequence<4>>; // OutScalarPerVectorSeq
 template <typename HostTensorA, typename HostTensorB, typename Functor>
...
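
The example switches from fp16 to fp32 data while halving the per-thread tile and the input vector width from 8 to 4, so each vectorized input access still moves 16 bytes. A quick check of that equivalence (plain C++, with uint16_t as a 2-byte stand-in for ck::half_t):

#include <cstdint>
#include <cstdio>

int main()
{
    // Bytes per vectorized access = sizeof(element) * ScalarPerVector.
    std::printf("fp16 x 8 = %zu bytes\n", sizeof(uint16_t) * 8); // old config: 16
    std::printf("fp32 x 4 = %zu bytes\n", sizeof(float) * 4);    // new config: 16
}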
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 #include <iostream>
 #include <cstdlib>
...
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 #include <iostream>
 #include <cstdlib>
...
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 #include <iostream>
 #include <cstdlib>
+#include <random>
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
@@ -48,10 +52,8 @@ void host_elementwise4D(HostTensorB& B_nhwc,
     for(std::size_t n = 0; n < N; ++n)
     {
         ADataType tmp_val;
-        // auto a_val = A_nchw(n, c, h, w);
         auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)];
         functor_b(tmp_val, a_val);
-        // functor_a(B_nhwc(n, h, w, c), scale * tmp_val);
         functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)],
                   scale * tmp_val);
     }
@@ -62,12 +64,14 @@ int main()
     bool do_verification = true;
     bool time_kernel = true;
-    std::vector<std::size_t> nchw = {4, 2, 1, 8};
-    std::vector<std::size_t> nhwc = {4, 1, 8, 2};
+    std::vector<std::size_t> nchw = {16, 8, 32, 64};
+    std::vector<std::size_t> nhwc = {16, 32, 64, 8};
     Tensor<ADataType> a(nchw);
     Tensor<BDataType> b(nhwc);
     float scale = 1.f;
     auto i = 0;
+    std::mt19937 gen(11939);
+    std::uniform_int_distribution<int> dis(0, 1);
     for(std::size_t w = 0; w < a.mDesc.GetLengths()[3]; ++w)
         for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h)
             for(std::size_t c = 0; c < a.mDesc.GetLengths()[1]; ++c)
@@ -75,7 +79,7 @@ int main()
                 {
                     a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) +
                             (h * nchw[3]) + w] = i;
-                    i++;
+                    i = dis(gen);
                 }
     DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
...
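
A plausible motivation for switching from sequential to random 0/1 test values (an inference, not stated in the commit): the tensor grows from 64 to 16 x 8 x 32 x 64 = 262,144 elements, and with a half-precision ADataType sequential values past 2048 are no longer exactly representable, which would break verification against the host reference. A tiny demonstration of that fp16 limit, using the _Float16 extension available in recent GCC and Clang:

#include <cstdio>

int main()
{
    // fp16 has a 10-bit mantissa, so integers above 2048 round away.
    _Float16 ok  = 2048; // exactly representable
    _Float16 bad = 2049; // rounds to the nearest representable value
    std::printf("%g %g\n", (double)ok, (double)bad); // prints: 2048 2048
}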
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 #include <iostream>
 #include <cstdlib>
...