Commit 2f463a94 authored by carlushuang

Merge remote-tracking branch 'origin/develop' into stream-k-initial-impl

parents ca8b5c79 ac9e01e2
+if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
 add_custom_target(example_convnd_fwd_reduce_xdl)
 add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
 add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
 add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
 add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
 add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
 add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
 add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
 add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
 if(USE_BITINT_EXTENSION_INT4)
     add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp)
     add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4)
 endif(USE_BITINT_EXTENSION_INT4)
+endif()
\ No newline at end of file
@@ -17,115 +17,11 @@
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/utility/literals.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp"
 
 template <typename InDataType,
           typename OutDataType,
-          typename AccDataType,
-          typename IndexDataType,
-          ck::ReduceTensorOp ReduceOpId,
-          bool PropagateNan,
-          bool OutputIndex>
-static void pool_host_verify(const Tensor<InDataType>& in,
-                             Tensor<OutDataType>& out,
-                             Tensor<IndexDataType>& out_indices,
-                             const std::array<ck::index_t, 2>& window_spatial_lengths,
-                             const std::array<ck::index_t, 2>& window_strides,
-                             const std::array<ck::index_t, 2>& in_left_pads,
-                             const std::array<ck::index_t, 2>& /*in_right_pads*/)
-{
-    const int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1];
-
-    using ReduceOperation = typename ck::reduce_binary_operator<ReduceOpId>::opType;
-
-    auto elementwise_ops =
-        ck::reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength);
-
-    auto in_elementwise_op  = std::get<0>(elementwise_ops);
-    auto acc_elementwise_op = std::get<1>(elementwise_ops);
-
-    if constexpr(!OutputIndex)
-    {
-        using Accumulation =
-            ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
-
-        auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
-            auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
-
-            for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
-            {
-                ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0];
-                for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x)
-                {
-                    ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1];
-                    if(hi >= 0 && hi < static_cast<ck::index_t>(in.mDesc.GetLengths()[2]) &&
-                       wi >= 0 && wi < static_cast<ck::index_t>(in.mDesc.GetLengths()[3]))
-                    {
-                        AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
-
-                        in_elementwise_op(currVal, currVal);
-
-                        Accumulation::Calculate(accuVal, currVal);
-                    }
-                }
-            }
-
-            acc_elementwise_op(accuVal, accuVal);
-
-            out(n, c, ho, wo) = accuVal;
-        };
-
-        make_ParallelTensorFunctor(f_nchw,
-                                   out.mDesc.GetLengths()[0],
-                                   out.mDesc.GetLengths()[1],
-                                   out.mDesc.GetLengths()[2],
-                                   out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
-    }
-    else
-    {
-        using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
-                                                                        ReduceOperation,
-                                                                        AccDataType,
-                                                                        IndexDataType>;
-        auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
-            auto accuVal            = ReduceOperation::template GetIdentityValue<AccDataType>();
-            IndexDataType accuIndex = 0;
-
-            for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
-            {
-                ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0];
-                for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x)
-                {
-                    ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1];
-                    if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
-                       wi < in.mDesc.GetLengths()[3])
-                    {
-                        AccDataType currVal     = static_cast<AccDataType>(in(n, c, hi, wi));
-                        IndexDataType currIndex = y * window_spatial_lengths[1] + x;
-
-                        in_elementwise_op(currVal, currVal);
-
-                        Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
-                    }
-                }
-            }
-
-            acc_elementwise_op(accuVal, accuVal);
-
-            out(n, c, ho, wo)         = accuVal;
-            out_indices(n, c, ho, wo) = accuIndex;
-        };
-
-        make_ParallelTensorFunctor(f_nchw,
-                                   out.mDesc.GetLengths()[0],
-                                   out.mDesc.GetLengths()[1],
-                                   out.mDesc.GetLengths()[2],
-                                   out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
-    };
-}
-
-template <typename InDataType,
-          typename OutDataType,
-          typename AccDataType,
+          typename ComputeDataType,
           typename IndexDataType,
           typename InLayout,
           typename OutLayout,
@@ -152,7 +48,8 @@ bool pool_test(bool do_verification,
         ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<
             InDataType,      // InDataType
             OutDataType,     // OutDataType
-            AccDataType,     // AccDataType
+            IndexDataType,   // IndexDataType
+            ComputeDataType, // ComputeDataType
             ReduceOpId,
             OutputIndex,
             64, // BlockSize
@@ -165,10 +62,10 @@ bool pool_test(bool do_verification,
     const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1;
     const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1;
 
-    const std::array<ck::index_t, 2> window_spatial_lengths{{Y, X}};
-    const std::array<ck::index_t, 2> window_strides{{window_stride_h, window_stride_w}};
-    const std::array<ck::index_t, 2> input_left_pads{{in_left_pad_h, in_left_pad_w}};
-    const std::array<ck::index_t, 2> input_right_pads{{in_right_pad_h, in_right_pad_w}};
+    const std::vector<ck::index_t> window_spatial_lengths{Y, X};
+    const std::vector<ck::index_t> window_strides{window_stride_h, window_stride_w};
+    const std::vector<ck::index_t> input_left_pads{in_left_pad_h, in_left_pad_w};
+    const std::vector<ck::index_t> input_right_pads{in_right_pad_h, in_right_pad_w};
 
     // tensor layout
     auto f_host_tensor_descriptor =
@@ -219,14 +116,16 @@ bool pool_test(bool do_verification,
         static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
         static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
         static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()),
-        N,
-        C,
-        std::array<ck::index_t, 2>{{Hi, Wi}},
-        std::array<ck::index_t, 2>{{Y, X}},
-        std::array<ck::index_t, 2>{{Ho, Wo}},
+        {N, C, Hi, Wi},
+        {Y, X},
+        {N, C, Ho, Wo},
+        {C * Hi * Wi, 1, Wi * C, C},
+        {C * Ho * Wo, 1, Wo * C, C},
+        {C * Ho * Wo, 1, Wo * C, C},
         window_strides,
         input_left_pads,
-        input_right_pads);
+        input_right_pads,
+        {2, 3});
 
     if(!pool.IsSupportedArgument(argument_ptr.get()))
     {
@@ -252,13 +151,20 @@ bool pool_test(bool do_verification,
 
     if(do_verification)
     {
-        pool_host_verify<InDataType,
-                         OutDataType,
-                         AccDataType,
-                         IndexDataType,
-                         ReduceOpId,
-                         PropagateNan,
-                         OutputIndex>(in_n_c_hi_wi,
+        using ReferencePoolingFwdInstance =
+            ck::tensor_operation::host::ReferencePoolingFwd<4,
+                                                            2,
+                                                            InDataType,
+                                                            OutDataType,
+                                                            ComputeDataType,
+                                                            IndexDataType,
+                                                            ReduceOpId,
+                                                            PropagateNan,
+                                                            OutputIndex>;
+
+        auto ref_pooling          = ReferencePoolingFwdInstance{};
+        auto ref_pooling_invoker  = ref_pooling.MakeInvoker();
+        auto ref_pooling_argument = ref_pooling.MakeArgument(in_n_c_hi_wi,
                                                              out_n_c_ho_wo_host,
                                                              out_indices_n_c_ho_wo_host,
                                                              window_spatial_lengths,
@@ -266,6 +172,8 @@ bool pool_test(bool do_verification,
                                                              input_left_pads,
                                                              input_right_pads);
 
+        ref_pooling_invoker.Run(ref_pooling_argument);
+
         out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data());
 
         pass = pass && ck::utils::check_err(out_n_c_ho_wo_device, out_n_c_ho_wo_host);
......
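Note on the verification change above: the hand-rolled pool_host_verify helper is gone, and correctness is now checked against the reference operator from ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp. A minimal sketch of the new flow, assuming the 4 and 2 template arguments (taken verbatim from the diff) mean tensor rank and number of pooled spatial dimensions, and using the type aliases from this example:

    // Sketch only; mirrors the instance/invoker/argument pattern shown in the diff above.
    using ReferencePoolingFwdInstance =
        ck::tensor_operation::host::ReferencePoolingFwd<4, // tensor rank (assumed meaning)
                                                        2, // pooled spatial dims (assumed meaning)
                                                        InDataType,
                                                        OutDataType,
                                                        ComputeDataType,
                                                        IndexDataType,
                                                        ReduceOpId,
                                                        PropagateNan,
                                                        OutputIndex>;

    auto ref_pooling          = ReferencePoolingFwdInstance{};
    auto ref_pooling_invoker  = ref_pooling.MakeInvoker();
    auto ref_pooling_argument = ref_pooling.MakeArgument(in_n_c_hi_wi,
                                                         out_n_c_ho_wo_host,
                                                         out_indices_n_c_ho_wo_host,
                                                         window_spatial_lengths,
                                                         window_strides,
                                                         input_left_pads,
                                                         input_right_pads);
    ref_pooling_invoker.Run(ref_pooling_argument);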
@@ -2,7 +2,6 @@
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
 
 #include <iostream>
-#include <cstdlib>
 
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
@@ -12,7 +11,7 @@
 using InDataType  = ck::half_t;
 using OutDataType = ck::half_t;
-using AccDataType = float;
+using ComputeDataType = float;
 using IndexDataType = int32_t;
@@ -91,7 +90,7 @@ int main(int argc, char* argv[])
     bool pass = pool_test<InDataType,
                           OutDataType,
-                          AccDataType,
+                          ComputeDataType,
                           IndexDataType,
                           InLayout,
                           OutLayout,
......
@@ -2,7 +2,6 @@
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
 
 #include <iostream>
-#include <cstdlib>
 
 #include "ck/ck.hpp"
 #include "ck/utility/reduction_enums.hpp"
@@ -12,7 +11,7 @@
 using InDataType  = float;
 using OutDataType = float;
-using AccDataType = float;
+using ComputeDataType = float;
 using IndexDataType = int32_t;
@@ -91,7 +90,7 @@ int main(int argc, char* argv[])
     bool pass = pool_test<InDataType,
                           OutDataType,
-                          AccDataType,
+                          ComputeDataType,
                           IndexDataType,
                           InLayout,
                           OutLayout,
......
@@ -2,5 +2,7 @@
 add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp)
 
 # xdlops
-add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp)
-add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp)
+if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
+    add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp)
+    add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp)
+endif()
\ No newline at end of file
+if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
 add_custom_target(example_gemm_reduce_xdl)
 add_custom_target(example_gemm_reduce_xdl_max)
 add_custom_target(example_gemm_reduce_xdl_mean_meansquare)
 add_custom_target(example_gemm_add_add_mean_meansquare_xdl)
 add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
 add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
 add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
 add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
 add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
 add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
 add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
 add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
 add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
 add_dependencies(example_gemm_reduce_xdl_max
                  example_gemm_max_xdl_bf16
                  example_gemm_max_xdl_fp16
                  example_gemm_max_xdl_fp32
                  example_gemm_max_xdl_int8)
 add_dependencies(example_gemm_reduce_xdl_mean_meansquare
                  example_gemm_mean_meansquare_xdl_fp16
                  example_gemm_mean_meansquare_xdl_fp32
                  example_gemm_mean_meansquare_xdl_bf16
                  example_gemm_add_addsquare_xdl_int8)
 add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
 add_dependencies(example_gemm_reduce_xdl
                  example_gemm_reduce_xdl_mean_meansquare
                  example_gemm_reduce_xdl_max
                  example_gemm_add_add_mean_meansquare_xdl)
 if(USE_BITINT_EXTENSION_INT4)
     add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp)
     add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4)
 endif()
+endif()
+if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
 add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp)
 target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility)
+endif()
 
 add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp)
 target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility)
+if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
 add_example_executable(example_batched_gemm_reduce_xdl_fp16 batched_gemm_reduce_xdl_fp16.cpp)
+endif()
+if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
 add_custom_target(example_grouped_conv_bwd_weight)
 add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
 add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
 add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16
                  example_grouped_conv_bwd_weight_xdl_bf16)
+endif()
 
 add_custom_target(example_grouped_conv_bwd_weight_dl)
......
@@ -18,7 +18,9 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
     // Set split_k = 2 for xdl op, split_k = 1 for dl
     // Dl op doesn't support split_k > 1
     // TODO: Add Dl op split_k > 1 support
-    if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030"))
+    if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
+         ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
+         ck::get_device_name() == "gfx1102"))
     {
         split_k = 2;
    }
......
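The check above encodes the rule stated in the comments: gfx906/gfx1030 and the gfx110x parts fall back to the dl kernels, which only support split_k = 1. A small sketch of the same gating with the check pulled into a helper; is_dl_only_target is a hypothetical name, not part of the diff:

    // Hypothetical helper (the diff itself keeps this check inline).
    #include <string>

    inline bool is_dl_only_target()
    {
        const std::string dev = ck::get_device_name();
        // These targets run the dl kernels, which do not support split_k > 1 yet.
        return dev == "gfx906" || dev == "gfx1030" || dev == "gfx1100" ||
               dev == "gfx1101" || dev == "gfx1102";
    }

    // Usage mirroring the diff: split_k stays 1 on dl-only targets, 2 otherwise.
    // ck::index_t split_k = is_dl_only_target() ? 1 : 2;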
+if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
 add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp)
 add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp)
 add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp)
 add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp)
+endif()
@@ -16,6 +16,7 @@
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/utility/numeric.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
 
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -74,141 +75,6 @@ using DeviceOpInstanceMNNN = ck::tensor_operation::device::
 
 using DeviceOpInstance = DeviceOpInstanceKKNN;
 
-// hardcoded for NumDimM == NumDimN == NumDimK == 2
-template <ck::index_t NumDimM,
-          ck::index_t NumDimN,
-          ck::index_t NumDimK,
-          typename ADataType,
-          typename BDataType,
-          typename EDataType,
-          typename AccDataType,
-          typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CDEElementwiseOperation,
-          ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
-struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
-{
-    // Argument
-    struct Argument : public ck::tensor_operation::device::BaseArgument
-    {
-        Argument(const Tensor<ADataType>& a_ms_ks,
-                 const Tensor<BDataType>& b_ns_ks,
-                 Tensor<EDataType>& e_ms_ns,
-                 AElementwiseOperation a_element_op,
-                 BElementwiseOperation b_element_op,
-                 CDEElementwiseOperation cde_element_op)
-            : a_ms_ks_{a_ms_ks},
-              b_ns_ks_{b_ns_ks},
-              e_ms_ns_{e_ms_ns},
-              a_element_op_{a_element_op},
-              b_element_op_{b_element_op},
-              cde_element_op_{cde_element_op}
-        {
-        }
-
-        const Tensor<ADataType>& a_ms_ks_;
-        const Tensor<BDataType>& b_ns_ks_;
-        Tensor<EDataType>& e_ms_ns_;
-
-        AElementwiseOperation a_element_op_;
-        BElementwiseOperation b_element_op_;
-        CDEElementwiseOperation cde_element_op_;
-    };
-
-    // Invoker
-    struct Invoker : public ck::tensor_operation::device::BaseInvoker
-    {
-        using Argument = ReferenceContraction_M2_N2_K2::Argument;
-
-        float Run(const Argument& arg)
-        {
-            auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
-                const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
-                const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
-
-                AccDataType v_acc = 0;
-
-                for(int k0 = 0; k0 < K0; ++k0)
-                {
-                    for(int k1 = 0; k1 < K1; ++k1)
-                    {
-                        AccDataType v_a;
-                        AccDataType v_b;
-
-                        arg.a_element_op_(
-                            v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
-                        arg.b_element_op_(
-                            v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
-
-                        v_acc += v_a * v_b;
-                    }
-                }
-
-                AccDataType v_c;
-
-                arg.cde_element_op_(v_c, v_acc);
-
-                arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
-            };
-
-            make_ParallelTensorFunctor(f_ms_ns,
-                                       arg.e_ms_ns_.mDesc.GetLengths()[0],
-                                       arg.e_ms_ns_.mDesc.GetLengths()[1],
-                                       arg.e_ms_ns_.mDesc.GetLengths()[2],
-                                       arg.e_ms_ns_.mDesc.GetLengths()[3])(
-                std::thread::hardware_concurrency());
-
-            return 0;
-        }
-
-        float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
-                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
-        {
-            return Run(*dynamic_cast<const Argument*>(p_arg));
-        }
-    };
-
-    static constexpr bool IsValidCompilationParameter()
-    {
-        // TODO: properly implement this check
-        return true;
-    }
-
-    bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
-    {
-        return true;
-    }
-
-    static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
-                             const Tensor<BDataType>& b_ns_ks,
-                             Tensor<EDataType>& e_ms_ns,
-                             AElementwiseOperation a_element_op,
-                             BElementwiseOperation b_element_op,
-                             CDEElementwiseOperation cde_element_op)
-    {
-        return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
-    }
-
-    static auto MakeInvoker() { return Invoker{}; }
-
-    virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
-    {
-        return std::make_unique<Invoker>(Invoker{});
-    }
-
-    std::string GetTypeString() const override
-    {
-        auto str = std::stringstream();
-
-        // clang-format off
-        str << "ReferenceContraction_M2_N2_K2"
-            << std::endl;
-        // clang-format on
-
-        return str.str();
-    }
-};
-
 int main(int argc, char* argv[])
 {
     bool do_verification = true;
@@ -385,7 +251,8 @@ int main(int argc, char* argv[])
     {
         Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
 
-        using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
+        using ReferenceOpInstance =
+            ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
                                                                       NumDimN,
                                                                       NumDimK,
                                                                       ADataType,
@@ -393,14 +260,13 @@ int main(int argc, char* argv[])
                                                                       CShuffleDataType,
                                                                       AccDataType,
                                                                       AElementOp,
-                                                                      BElementOp,
-                                                                      PassThrough>;
+                                                                      BElementOp>;
 
-        auto ref_gemm    = ReferenceOpInstance{};
-        auto ref_invoker = ref_gemm.MakeInvoker();
+        auto ref_op      = ReferenceOpInstance{};
+        auto ref_invoker = ref_op.MakeInvoker();
 
-        auto ref_argument = ref_gemm.MakeArgument(
-            a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
+        auto ref_argument =
+            ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
 
         ref_invoker.Run(ref_argument);
......
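As with the pooling example, verification in these contraction examples now goes through the library's CPU reference in reference_contraction.hpp instead of a per-example local struct. A sketch of the updated call pattern; the template parameter between ADataType and CShuffleDataType sits in a collapsed hunk and is assumed (from the removed local struct) to be BDataType:

    // Sketch only; names are taken from the diff above, BDataType is an assumption.
    using ReferenceOpInstance =
        ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
                                                                  NumDimN,
                                                                  NumDimK,
                                                                  ADataType,
                                                                  BDataType, // assumed
                                                                  CShuffleDataType,
                                                                  AccDataType,
                                                                  AElementOp,
                                                                  BElementOp>;

    auto ref_op      = ReferenceOpInstance{};
    auto ref_invoker = ref_op.MakeInvoker();
    auto ref_argument =
        ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
    ref_invoker.Run(ref_argument);

Note that the CDEElementwiseOperation (PassThrough in the old code) is no longer a template or call argument of the reference instance.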
@@ -16,6 +16,7 @@
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/utility/numeric.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
 
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -74,141 +75,6 @@ using DeviceOpInstanceMNNN = ck::tensor_operation::device::
 
 using DeviceOpInstance = DeviceOpInstanceKKNN;
 
-// hardcoded for NumDimM == NumDimN == NumDimK == 2
-template <ck::index_t NumDimM,
-          ck::index_t NumDimN,
-          ck::index_t NumDimK,
-          typename ADataType,
-          typename BDataType,
-          typename EDataType,
-          typename AccDataType,
-          typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CDEElementwiseOperation,
-          ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
-struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
-{
-    // Argument
-    struct Argument : public ck::tensor_operation::device::BaseArgument
-    {
-        Argument(const Tensor<ADataType>& a_ms_ks,
-                 const Tensor<BDataType>& b_ns_ks,
-                 Tensor<EDataType>& e_ms_ns,
-                 AElementwiseOperation a_element_op,
-                 BElementwiseOperation b_element_op,
-                 CDEElementwiseOperation cde_element_op)
-            : a_ms_ks_{a_ms_ks},
-              b_ns_ks_{b_ns_ks},
-              e_ms_ns_{e_ms_ns},
-              a_element_op_{a_element_op},
-              b_element_op_{b_element_op},
-              cde_element_op_{cde_element_op}
-        {
-        }
-
-        const Tensor<ADataType>& a_ms_ks_;
-        const Tensor<BDataType>& b_ns_ks_;
-        Tensor<EDataType>& e_ms_ns_;
-
-        AElementwiseOperation a_element_op_;
-        BElementwiseOperation b_element_op_;
-        CDEElementwiseOperation cde_element_op_;
-    };
-
-    // Invoker
-    struct Invoker : public ck::tensor_operation::device::BaseInvoker
-    {
-        using Argument = ReferenceContraction_M2_N2_K2::Argument;
-
-        float Run(const Argument& arg)
-        {
-            auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
-                const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
-                const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
-
-                AccDataType v_acc = 0;
-
-                for(int k0 = 0; k0 < K0; ++k0)
-                {
-                    for(int k1 = 0; k1 < K1; ++k1)
-                    {
-                        AccDataType v_a;
-                        AccDataType v_b;
-
-                        arg.a_element_op_(
-                            v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
-                        arg.b_element_op_(
-                            v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
-
-                        v_acc += v_a * v_b;
-                    }
-                }
-
-                AccDataType v_c;
-
-                arg.cde_element_op_(v_c, v_acc);
-
-                arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
-            };
-
-            make_ParallelTensorFunctor(f_ms_ns,
-                                       arg.e_ms_ns_.mDesc.GetLengths()[0],
-                                       arg.e_ms_ns_.mDesc.GetLengths()[1],
-                                       arg.e_ms_ns_.mDesc.GetLengths()[2],
-                                       arg.e_ms_ns_.mDesc.GetLengths()[3])(
-                std::thread::hardware_concurrency());
-
-            return 0;
-        }
-
-        float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
-                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
-        {
-            return Run(*dynamic_cast<const Argument*>(p_arg));
-        }
-    };
-
-    static constexpr bool IsValidCompilationParameter()
-    {
-        // TODO: properly implement this check
-        return true;
-    }
-
-    bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
-    {
-        return true;
-    }
-
-    static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
-                             const Tensor<BDataType>& b_ns_ks,
-                             Tensor<EDataType>& e_ms_ns,
-                             AElementwiseOperation a_element_op,
-                             BElementwiseOperation b_element_op,
-                             CDEElementwiseOperation cde_element_op)
-    {
-        return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
-    }
-
-    static auto MakeInvoker() { return Invoker{}; }
-
-    virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
-    {
-        return std::make_unique<Invoker>(Invoker{});
-    }
-
-    std::string GetTypeString() const override
-    {
-        auto str = std::stringstream();
-
-        // clang-format off
-        str << "ReferenceContraction_M2_N2_K2"
-            << std::endl;
-        // clang-format on
-
-        return str.str();
-    }
-};
-
 int main(int argc, char* argv[])
 {
     bool do_verification = true;
@@ -385,7 +251,8 @@ int main(int argc, char* argv[])
     {
         Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
 
-        using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
+        using ReferenceOpInstance =
+            ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
                                                                       NumDimN,
                                                                       NumDimK,
                                                                       ADataType,
@@ -393,14 +260,13 @@ int main(int argc, char* argv[])
                                                                       CShuffleDataType,
                                                                       AccDataType,
                                                                       AElementOp,
-                                                                      BElementOp,
-                                                                      PassThrough>;
+                                                                      BElementOp>;
 
-        auto ref_gemm    = ReferenceOpInstance{};
-        auto ref_invoker = ref_gemm.MakeInvoker();
+        auto ref_op      = ReferenceOpInstance{};
+        auto ref_invoker = ref_op.MakeInvoker();
 
-        auto ref_argument = ref_gemm.MakeArgument(
-            a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
+        auto ref_argument =
+            ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
 
         ref_invoker.Run(ref_argument);
......
@@ -16,6 +16,7 @@
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/utility/numeric.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
 
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -73,141 +74,6 @@ using DeviceOpInstanceMNN = ck::tensor_operation::device::
 
 using DeviceOpInstance = DeviceOpInstanceKKN;
 
-// hardcoded for NumDimM == NumDimN == NumDimK == 2
-template <ck::index_t NumDimM,
-          ck::index_t NumDimN,
-          ck::index_t NumDimK,
-          typename ADataType,
-          typename BDataType,
-          typename EDataType,
-          typename AccDataType,
-          typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CDEElementwiseOperation,
-          ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
-struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
-{
-    // Argument
-    struct Argument : public ck::tensor_operation::device::BaseArgument
-    {
-        Argument(const Tensor<ADataType>& a_ms_ks,
-                 const Tensor<BDataType>& b_ns_ks,
-                 Tensor<EDataType>& e_ms_ns,
-                 AElementwiseOperation a_element_op,
-                 BElementwiseOperation b_element_op,
-                 CDEElementwiseOperation cde_element_op)
-            : a_ms_ks_{a_ms_ks},
-              b_ns_ks_{b_ns_ks},
-              e_ms_ns_{e_ms_ns},
-              a_element_op_{a_element_op},
-              b_element_op_{b_element_op},
-              cde_element_op_{cde_element_op}
-        {
-        }
-
-        const Tensor<ADataType>& a_ms_ks_;
-        const Tensor<BDataType>& b_ns_ks_;
-        Tensor<EDataType>& e_ms_ns_;
-
-        AElementwiseOperation a_element_op_;
-        BElementwiseOperation b_element_op_;
-        CDEElementwiseOperation cde_element_op_;
-    };
-
-    // Invoker
-    struct Invoker : public ck::tensor_operation::device::BaseInvoker
-    {
-        using Argument = ReferenceContraction_M2_N2_K2::Argument;
-
-        float Run(const Argument& arg)
-        {
-            auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
-                const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
-                const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
-
-                AccDataType v_acc = 0;
-
-                for(int k0 = 0; k0 < K0; ++k0)
-                {
-                    for(int k1 = 0; k1 < K1; ++k1)
-                    {
-                        AccDataType v_a;
-                        AccDataType v_b;
-
-                        arg.a_element_op_(
-                            v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
-                        arg.b_element_op_(
-                            v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
-
-                        v_acc += v_a * v_b;
-                    }
-                }
-
-                AccDataType v_c;
-
-                arg.cde_element_op_(v_c, v_acc);
-
-                arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
-            };
-
-            make_ParallelTensorFunctor(f_ms_ns,
-                                       arg.e_ms_ns_.mDesc.GetLengths()[0],
-                                       arg.e_ms_ns_.mDesc.GetLengths()[1],
-                                       arg.e_ms_ns_.mDesc.GetLengths()[2],
-                                       arg.e_ms_ns_.mDesc.GetLengths()[3])(
-                std::thread::hardware_concurrency());
-
-            return 0;
-        }
-
-        float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
-                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
-        {
-            return Run(*dynamic_cast<const Argument*>(p_arg));
-        }
-    };
-
-    static constexpr bool IsValidCompilationParameter()
-    {
-        // TODO: properly implement this check
-        return true;
-    }
-
-    bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
-    {
-        return true;
-    }
-
-    static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
-                             const Tensor<BDataType>& b_ns_ks,
-                             Tensor<EDataType>& e_ms_ns,
-                             AElementwiseOperation a_element_op,
-                             BElementwiseOperation b_element_op,
-                             CDEElementwiseOperation cde_element_op)
-    {
-        return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
-    }
-
-    static auto MakeInvoker() { return Invoker{}; }
-
-    virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
-    {
-        return std::make_unique<Invoker>(Invoker{});
-    }
-
-    std::string GetTypeString() const override
-    {
-        auto str = std::stringstream();
-
-        // clang-format off
-        str << "ReferenceContraction_M2_N2_K2"
-            << std::endl;
-        // clang-format on
-
-        return str.str();
-    }
-};
-
 int main(int argc, char* argv[])
 {
     bool do_verification = true;
@@ -368,7 +234,8 @@ int main(int argc, char* argv[])
     {
         Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
 
-        using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
+        using ReferenceOpInstance =
+            ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
                                                                       NumDimN,
                                                                       NumDimK,
                                                                       ADataType,
@@ -376,14 +243,14 @@ int main(int argc, char* argv[])
                                                                       CShuffleDataType,
                                                                       AccDataType,
                                                                       AElementOp,
-                                                                      BElementOp,
-                                                                      PassThrough>;
+                                                                      BElementOp>;
 
-        auto ref_gemm    = ReferenceOpInstance{};
-        auto ref_invoker = ref_gemm.MakeInvoker();
+        auto ref_op      = ReferenceOpInstance{};
+        auto ref_invoker = ref_op.MakeInvoker();
 
-        auto ref_argument = ref_gemm.MakeArgument(
-            a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
+        Tensor<float> empty_tensor(std::vector<ck::index_t>{}, std::vector<ck::index_t>{});
+        auto ref_argument =
+            ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
 
         ref_invoker.Run(ref_argument);
......
@@ -16,6 +16,7 @@
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/utility/numeric.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
 
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -73,141 +74,6 @@ using DeviceOpInstanceMNN = ck::tensor_operation::device::
 
 using DeviceOpInstance = DeviceOpInstanceKKN;
 
-// hardcoded for NumDimM == NumDimN == NumDimK == 2
-template <ck::index_t NumDimM,
-          ck::index_t NumDimN,
-          ck::index_t NumDimK,
-          typename ADataType,
-          typename BDataType,
-          typename EDataType,
-          typename AccDataType,
-          typename AElementwiseOperation,
-          typename BElementwiseOperation,
-          typename CDEElementwiseOperation,
-          ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
-struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
-{
-    // Argument
-    struct Argument : public ck::tensor_operation::device::BaseArgument
-    {
-        Argument(const Tensor<ADataType>& a_ms_ks,
-                 const Tensor<BDataType>& b_ns_ks,
-                 Tensor<EDataType>& e_ms_ns,
-                 AElementwiseOperation a_element_op,
-                 BElementwiseOperation b_element_op,
-                 CDEElementwiseOperation cde_element_op)
-            : a_ms_ks_{a_ms_ks},
-              b_ns_ks_{b_ns_ks},
-              e_ms_ns_{e_ms_ns},
-              a_element_op_{a_element_op},
-              b_element_op_{b_element_op},
-              cde_element_op_{cde_element_op}
-        {
-        }
-
-        const Tensor<ADataType>& a_ms_ks_;
-        const Tensor<BDataType>& b_ns_ks_;
-        Tensor<EDataType>& e_ms_ns_;
-
-        AElementwiseOperation a_element_op_;
-        BElementwiseOperation b_element_op_;
-        CDEElementwiseOperation cde_element_op_;
-    };
-
-    // Invoker
-    struct Invoker : public ck::tensor_operation::device::BaseInvoker
-    {
-        using Argument = ReferenceContraction_M2_N2_K2::Argument;
-
-        float Run(const Argument& arg)
-        {
-            auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
-                const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
-                const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
-
-                AccDataType v_acc = 0;
-
-                for(int k0 = 0; k0 < K0; ++k0)
-                {
-                    for(int k1 = 0; k1 < K1; ++k1)
-                    {
-                        AccDataType v_a;
-                        AccDataType v_b;
-
-                        arg.a_element_op_(
-                            v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
-                        arg.b_element_op_(
-                            v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
-
-                        v_acc += v_a * v_b;
-                    }
-                }
-
-                AccDataType v_c;
-
-                arg.cde_element_op_(v_c, v_acc);
-
-                arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
-            };
-
-            make_ParallelTensorFunctor(f_ms_ns,
-                                       arg.e_ms_ns_.mDesc.GetLengths()[0],
-                                       arg.e_ms_ns_.mDesc.GetLengths()[1],
-                                       arg.e_ms_ns_.mDesc.GetLengths()[2],
-                                       arg.e_ms_ns_.mDesc.GetLengths()[3])(
-                std::thread::hardware_concurrency());
-
-            return 0;
-        }
-
-        float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
-                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
-        {
-            return Run(*dynamic_cast<const Argument*>(p_arg));
-        }
-    };
-
-    static constexpr bool IsValidCompilationParameter()
-    {
-        // TODO: properly implement this check
-        return true;
-    }
-
-    bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
-    {
-        return true;
-    }
-
-    static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
-                             const Tensor<BDataType>& b_ns_ks,
-                             Tensor<EDataType>& e_ms_ns,
-                             AElementwiseOperation a_element_op,
-                             BElementwiseOperation b_element_op,
-                             CDEElementwiseOperation cde_element_op)
-    {
-        return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
-    }
-
-    static auto MakeInvoker() { return Invoker{}; }
-
-    virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
-    {
-        return std::make_unique<Invoker>(Invoker{});
-    }
-
-    std::string GetTypeString() const override
-    {
-        auto str = std::stringstream();
-
-        // clang-format off
-        str << "ReferenceContraction_M2_N2_K2"
-            << std::endl;
-        // clang-format on
-
-        return str.str();
-    }
-};
-
 int main(int argc, char* argv[])
 {
     bool do_verification = true;
@@ -368,7 +234,8 @@ int main(int argc, char* argv[])
    {
         Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
 
-        using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
+        using ReferenceOpInstance =
+            ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
                                                                       NumDimN,
                                                                       NumDimK,
                                                                       ADataType,
@@ -376,14 +243,14 @@ int main(int argc, char* argv[])
                                                                       CShuffleDataType,
                                                                       AccDataType,
                                                                       AElementOp,
-                                                                      BElementOp,
-                                                                      PassThrough>;
+                                                                      BElementOp>;
 
-        auto ref_gemm    = ReferenceOpInstance{};
-        auto ref_invoker = ref_gemm.MakeInvoker();
+        auto ref_op      = ReferenceOpInstance{};
+        auto ref_invoker = ref_op.MakeInvoker();
 
-        auto ref_argument = ref_gemm.MakeArgument(
-            a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
+        Tensor<float> empty_tensor(std::vector<ck::index_t>{}, std::vector<ck::index_t>{});
+        auto ref_argument =
+            ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
 
         ref_invoker.Run(ref_argument);
......
+if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
 add_custom_target(example_grouped_conv_fwd_multiple_d)
 add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp)
 add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
 add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp)
 add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp)
 add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16)
 add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)
 add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16)
 add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8)
 if(USE_BITINT_EXTENSION_INT4)
     add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp)
     add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
 endif() # USE_BITINT_EXTENSION_INT4
+add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
+add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
+endif()
 
 if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
     add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
 endif()
-add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
-add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
+if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
 add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
 add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
 add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
 if(NOT GPU_TARGETS MATCHES "gfx940")
     add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
 endif()
 if(USE_BITINT_EXTENSION_INT4)
     add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
 endif(USE_BITINT_EXTENSION_INT4)
+endif()
\ No newline at end of file
+if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
 add_custom_target(example_splitK_gemm_xdl)
 add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp)
 add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
 add_example_executable(example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp)
 add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp)
 add_dependencies(example_splitK_gemm_xdl
                  example_splitK_gemm_xdl_fp32
                  example_splitK_gemm_xdl_fp16
                  example_splitK_gemm_xdl_bfp16
                  example_splitK_gemm_xdl_int8)
 if(USE_BITINT_EXTENSION_INT4)
     add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp)
     add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4)
 endif()
+endif()
+if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
 add_custom_target(example_grouped_conv_bwd_data)
 add_example_executable(example_grouped_conv_bwd_data_fp16 grouped_conv_bwd_data_fp16.cpp)
 add_example_executable(example_grouped_conv_bwd_data_bias_relu_fp16 grouped_conv_bwd_data_bias_relu_fp16.cpp)
 add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_fp16)
 add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_fp16)
+endif()
\ No newline at end of file
+if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
+    add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp)
+    add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp)
+    add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp)
+    add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp)
+endif()
 
 # Conv perlayer quantization
 add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp)
-add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp)
 
 # Conv perchannel quantization
 add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp)
-add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp)
 
 # Conv + bias + relu perlayer quantization
 add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp)
-add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp)
 
 # Conv + bias + relu perchannel quantization
 add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp)
-add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp)
 
 # Conv + bias + tanh perlayer quantization
 add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp)
......