Commit 175a17f8 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Merge branch 'gfx950' of https://github.com/ROCm/composable_kernel-internal into lwpck-2390

parents 3e520bbd 1504c3e8
...@@ -183,13 +183,24 @@ message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}") ...@@ -183,13 +183,24 @@ message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}")
if (SUPPORTED_GPU_TARGETS MATCHES "gfx9") if (SUPPORTED_GPU_TARGETS MATCHES "gfx9")
message("Enabling XDL instances") message("Enabling XDL instances")
add_definitions(-DCK_USE_XDL) add_definitions(-DCK_USE_XDL)
set(CK_USE_XDL "ON") endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95")
message("Enabling FP8 gemms on native architectures")
add_definitions(-DCK_USE_GFX94)
endif() endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
message("Enabling WMMA instances") message("Enabling WMMA instances")
add_definitions(-DCK_USE_WMMA) add_definitions(-DCK_USE_WMMA)
set(CK_USE_WMMA "ON")
endif() endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950")
add_definitions(-DCK_USE_OCP_FP8)
set(CK_USE_OCP_FP8 "ON")
endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx94")
add_definitions(-DCK_USE_FNUZ_FP8)
set(CK_USE_FNUZ_FP8 "ON")
endif()
option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908")) if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH) add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH)
...@@ -221,7 +232,7 @@ if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600140090) ...@@ -221,7 +232,7 @@ if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600140090)
endif() endif()
set(check-coerce) set(check-coerce)
check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce) check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce)
if(NOT WIN32 AND check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132 AND ${hip_VERSION_FLAT} LESS 600300000) if(NOT WIN32 AND check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132)
message("Adding the amdgpu-coerce-illegal-types=1") message("Adding the amdgpu-coerce-illegal-types=1")
add_compile_options("SHELL: -mllvm -amdgpu-coerce-illegal-types=1") add_compile_options("SHELL: -mllvm -amdgpu-coerce-illegal-types=1")
endif() endif()
......
...@@ -24,10 +24,10 @@ RUN if [ "$ROCMVERSION" != "6.3" ]; then \ ...@@ -24,10 +24,10 @@ RUN if [ "$ROCMVERSION" != "6.3" ]; then \
sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \
sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \ sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \
elif [ "$ROCMVERSION" = "6.3" ] && [ "$compiler_version" = "rc1" ]; then \ elif [ "$ROCMVERSION" = "6.3" ] && [ "$compiler_version" = "rc1" ]; then \
sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.3.0.1-20.04-1_all.deb --no-check-certificate" && \ sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.3-20.04-1_all.deb --no-check-certificate" && \
apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog libpopt0 rsync && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.3.0.1-20.04-1_all.deb && \ apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog libpopt0 rsync && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.3-20.04-1_all.deb && \
sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.3.0.1 rel-5 > /etc/apt/sources.list.d/rocm-build.list' && \ sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.3 rel-20 > /etc/apt/sources.list.d/rocm-build.list' && \
amdgpu-repo --amdgpu-build=2033700; \ amdgpu-repo --amdgpu-build=2074281; \
fi fi
RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list"
......
...@@ -56,6 +56,14 @@ if (GPU_TARGETS) ...@@ -56,6 +56,14 @@ if (GPU_TARGETS)
add_definitions(-DCK_USE_WMMA) add_definitions(-DCK_USE_WMMA)
set(CK_USE_WMMA "ON") set(CK_USE_WMMA "ON")
endif() endif()
if (GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx950")
add_definitions(-DCK_USE_OCP_FP8)
set(CK_USE_OCP_FP8 "ON")
endif()
if (GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx94")
add_definitions(-DCK_USE_FNUZ_FP8)
set(CK_USE_FNUZ_FP8 "ON")
endif()
else() else()
add_definitions(-DCK_USE_WMMA -DCK_USE_XDL) add_definitions(-DCK_USE_WMMA -DCK_USE_XDL)
set(CK_USE_XDL "ON") set(CK_USE_XDL "ON")
......
rocm-docs-core==1.8.3 rocm-docs-core==1.8.4
sphinxcontrib-bibtex==2.6.3 sphinxcontrib-bibtex==2.6.3
...@@ -103,7 +103,7 @@ requests==2.32.3 ...@@ -103,7 +103,7 @@ requests==2.32.3
# via # via
# pygithub # pygithub
# sphinx # sphinx
rocm-docs-core==1.8.3 rocm-docs-core==1.8.4
# via -r requirements.in # via -r requirements.in
six==1.16.0 six==1.16.0
# via pybtex # via pybtex
......
...@@ -76,7 +76,7 @@ struct ProblemSizeSplitK final ...@@ -76,7 +76,7 @@ struct ProblemSizeSplitK final
struct ExecutionConfig final struct ExecutionConfig final
{ {
// 0 - no verification, 1 - CPU, 2 - GPU, 3 - CPU + GPU // 0 - no verification, 1 - CPU, 2 - GPU, 3 - CPU + GPU
int do_verification = 3; int do_verification = 1;
int init_method = 2; int init_method = 2;
bool time_kernel = false; bool time_kernel = false;
}; };
......
...@@ -143,8 +143,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -143,8 +143,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
switch(config.init_method) switch(config.init_method)
{ {
case 0: case 0:
ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k); ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.f)}(a_m_k);
ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n); ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(1.f)}(b_k_n);
break; break;
case 1: case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
......
...@@ -16,7 +16,7 @@ if(USE_BITINT_EXTENSION_INT4) ...@@ -16,7 +16,7 @@ if(USE_BITINT_EXTENSION_INT4)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4) endif(USE_BITINT_EXTENSION_INT4)
list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942 gfx950)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
......
...@@ -80,7 +80,7 @@ using RLayout = typename LayoutSettingSelector<NDimSpatial>::RLayout; ...@@ -80,7 +80,7 @@ using RLayout = typename LayoutSettingSelector<NDimSpatial>::RLayout;
struct ExecutionConfig final struct ExecutionConfig final
{ {
bool do_verification = true; bool do_verification = true;
int init_method = 1; int init_method = 2;
bool time_kernel = false; bool time_kernel = false;
}; };
......
...@@ -73,16 +73,25 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size, ...@@ -73,16 +73,25 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
Tensor<EDataType> conv_output_device(conv_output_g_n_k_wos_desc); Tensor<EDataType> conv_output_device(conv_output_g_n_k_wos_desc);
Tensor<R0DataType> r0_device(r0_desc); Tensor<R0DataType> r0_device(r0_desc);
std::cout << "input: " << conv_input.mDesc << std::endl;
std::cout << "weight: " << conv_weight.mDesc << std::endl;
std::cout << "output: " << conv_output_device.mDesc << std::endl;
std::cout << "reduction: " << r0_device.mDesc << std::endl << std::endl;
switch(config.init_method) switch(config.init_method)
{ {
case 0: break; case 0: break;
case 1: case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-8, 7}(conv_input); ck::utils::FillUniformDistributionIntegerValue<ADataType>{-8, 7}(conv_input);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-8, 7}(conv_weight); ck::utils::FillUniformDistributionIntegerValue<BDataType>{-1, 1}(conv_weight);
break;
case 2:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-8, 7}(conv_input);
ck::utils::FillUniformDistribution<BDataType>{-1, 1}(conv_weight);
break; break;
default: default:
ck::utils::FillUniformDistribution<ADataType>{-5, 5}(conv_input); ck::utils::FillUniformDistribution<ADataType>{-8, 7}(conv_input);
ck::utils::FillUniformDistribution<BDataType>{-5, 5}(conv_weight); ck::utils::FillUniformDistribution<BDataType>{-1, 1}(conv_weight);
} }
DeviceMem conv_input_device_buf(sizeof(ADataType) * conv_input.mDesc.GetElementSpaceSize()); DeviceMem conv_input_device_buf(sizeof(ADataType) * conv_input.mDesc.GetElementSpaceSize());
...@@ -161,15 +170,25 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size, ...@@ -161,15 +170,25 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
return false; return false;
} }
// XXX: DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle will not initialize r0.
r0_device_buf.SetValue(ck::NumericLimits<R0DataType>::Lowest());
const float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); const float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
if(config.time_kernel)
{
const std::size_t flop = problem_size.GetFlops(); const std::size_t flop = problem_size.GetFlops();
const std::size_t num_btype = problem_size.GetByte<ADataType, BDataType, EDataType>(); const std::size_t num_btype = problem_size.GetByte<ADataType, BDataType, EDataType>();
const float tflops = static_cast<float>(flop) / 1.E9 / avg_time; const float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
const float gb_per_sec = num_btype / 1.E6 / avg_time; const float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< conv.GetTypeString() << std::endl; << " GB/s, " << conv.GetTypeString() << std::endl;
}
else
{
std::cout << "FINISHED: " << conv.GetTypeString() << std::endl;
}
if(config.do_verification) if(config.do_verification)
{ {
...@@ -189,6 +208,7 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size, ...@@ -189,6 +208,7 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
BElementOp{}, BElementOp{},
PassThrough{}); PassThrough{});
std::cout << "\nRunning verification on CPU." << std::endl;
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
Tensor<R0DataType> r0_host(r0_device.mDesc); Tensor<R0DataType> r0_host(r0_device.mDesc);
...@@ -273,13 +293,18 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size, ...@@ -273,13 +293,18 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
conv_output_device_buf.FromDevice(conv_output_device.mData.data()); conv_output_device_buf.FromDevice(conv_output_device.mData.data());
r0_device_buf.FromDevice(r0_device.mData.data()); r0_device_buf.FromDevice(r0_device.mData.data());
return ck::utils::check_err(conv_output_device, auto pass = ck::utils::check_err(conv_output_device,
conv_output_host, conv_output_host,
"Error: incorrect results! (Matrix E)", "Error: incorrect results! (Matrix E)",
1e-5f, 1e-3f,
1e-4f) && 1e-3f);
ck::utils::check_err( pass =
r0_device, r0_host, "Error: incorrect results! (Matrix R0)", 1e-5f, 1e-4f); pass && ck::utils::check_err(
r0_device, r0_host, "Error: incorrect results! (Matrix R0)", 1e-3f, 1e-3f);
if(pass)
std::cout << "Verification on CPU: PASS" << std::endl;
return pass;
} }
return true; return true;
......
...@@ -186,15 +186,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -186,15 +186,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
for(int j = 0; j < NumDMatrices; ++j) for(int j = 0; j < NumDMatrices; ++j)
{ {
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
} }
break; break;
default: default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
for(int j = 0; j < NumDMatrices; ++j) for(int j = 0; j < NumDMatrices; ++j)
{ {
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
} }
} }
} }
......
...@@ -190,15 +190,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -190,15 +190,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
for(int j = 0; j < NumDs; ++j) for(int j = 0; j < NumDs; ++j)
{ {
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
} }
break; break;
default: default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
for(int j = 0; j < NumDs; ++j) for(int j = 0; j < NumDs; ++j)
{ {
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
} }
} }
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -167,11 +167,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -167,11 +167,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break; break;
default: default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
} }
d0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); d0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<D0DataType, 1>{});
} }
using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<1>; using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<1>;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -157,8 +157,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -157,8 +157,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break; break;
default: default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
} }
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -158,8 +158,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -158,8 +158,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break; break;
default: default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
} }
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
struct ProblemSize final struct ProblemSize final
...@@ -124,8 +127,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -124,8 +127,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break; break;
default: default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
} }
} }
......
...@@ -198,7 +198,7 @@ int main() ...@@ -198,7 +198,7 @@ int main()
throw std::runtime_error("wrong! this device_op instance does not support this problem"); throw std::runtime_error("wrong! this device_op instance does not support this problem");
} }
// init reducetion buffer to 0 // init reduction buffer to 0
r0_device_buf.SetZero(); r0_device_buf.SetZero();
r1_device_buf.SetZero(); r1_device_buf.SetZero();
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -175,8 +175,8 @@ int main(int argc, char* argv[]) ...@@ -175,8 +175,8 @@ int main(int argc, char* argv[])
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break; break;
default: default:
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
} }
c0_n_bias.GenerateTensorValue(GeneratorTensor_2<C0DataType>{-5, 5}); c0_n_bias.GenerateTensorValue(GeneratorTensor_2<C0DataType>{-5, 5});
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -150,7 +150,7 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[]) ...@@ -150,7 +150,7 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[])
break; break;
default: default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1}); a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
int run(int argc, char* argv[]) int run(int argc, char* argv[])
{ {
...@@ -157,7 +157,7 @@ int run(int argc, char* argv[]) ...@@ -157,7 +157,7 @@ int run(int argc, char* argv[])
break; break;
default: default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1}); a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment