Commit e24f37fb authored by aska-0096's avatar aska-0096
Browse files

Merge branch 'develop' of...

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into add_int8_wmma_example_instance
parents 9697ad4e 85e2e1e2
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <numeric>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <numeric>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <numeric>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <numeric>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <vector>
......@@ -100,6 +100,10 @@ int main(int argc, char* argv[])
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N +
......@@ -153,6 +157,10 @@ int main(int argc, char* argv[])
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <functional>
#include <numeric>
......@@ -53,12 +53,35 @@ int main(int argc, char* argv[])
SimpleDeviceMem in(sizeof(InDataType) * num_elements);
SimpleDeviceMem out(sizeof(OutDataType) * num_elements);
using DeviceOp = ck::tensor_operation::device::
DeviceSoftmax<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>;
using DeviceOp = ck::tensor_operation::device::DeviceSoftmax<InDataType,
AccDataType,
OutDataType,
PassThrough,
PassThrough,
Rank,
NumReduceDim>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
auto& generic_op_ptr = op_ptrs[0];
auto generic_argument_ptr = generic_op_ptr->MakeArgumentPointer(in_lengths,
in_strides,
reduce_dims,
alpha,
beta,
in.GetDeviceBuffer(),
out.GetDeviceBuffer(),
PassThrough{},
PassThrough{});
if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get()))
{
throw std::runtime_error(
"The generic kernel instance should be able to support any input shapes");
};
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
......@@ -74,11 +97,6 @@ int main(int argc, char* argv[])
{
auto& op_ptr = op_ptrs[i];
if(op_ptr->GetRank() != Rank || op_ptr->GetNumReduceDim() != NumReduceDim)
{
continue;
}
auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths,
in_strides,
reduce_dims,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iomanip>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iomanip>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
......
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_operations)
......@@ -18,3 +19,4 @@ target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable
add_executable(client_gemm_quantization gemm_quantization.cpp)
target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_operations)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iomanip>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iomanip>
......@@ -32,63 +32,49 @@ struct SimpleDeviceMem
};
template <ck::index_t NumDimSpatial>
std::size_t GetFlops(ck::index_t G,
ck::index_t N,
ck::index_t K,
ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths)
std::size_t GetFlops(const std::array<ck::index_t, NumDimSpatial>& output_lengths,
const std::array<ck::index_t, NumDimSpatial>& filter_lengths)
{
constexpr ck::index_t spatial_offset = 3;
const auto C = filter_lengths[2];
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
return static_cast<std::size_t>(2) * G * N * K * C *
std::accumulate(std::begin(output_spatial_lengths),
std::end(output_spatial_lengths),
return static_cast<std::size_t>(2) * C *
std::accumulate(std::begin(output_lengths),
std::end(output_lengths),
static_cast<std::size_t>(1),
std::multiplies<>()) *
std::accumulate(std::begin(filter_spatial_lengths),
std::end(filter_spatial_lengths),
std::accumulate(std::begin(filter_lengths) + spatial_offset,
std::end(filter_lengths),
static_cast<std::size_t>(1),
std::multiplies<>());
}
template <typename InDataType, ck::index_t NumDimSpatial>
std::size_t GetInputByte(ck::index_t G,
ck::index_t N,
ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths)
std::size_t GetInputByte(const std::array<ck::index_t, NumDimSpatial>& input_lengths)
{
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
return sizeof(InDataType) * (G * N * C *
std::accumulate(std::begin(input_spatial_lengths),
std::end(input_spatial_lengths),
return sizeof(InDataType) * (std::accumulate(std::begin(input_lengths),
std::end(input_lengths),
static_cast<std::size_t>(1),
std::multiplies<>()));
}
template <typename WeiDataType, ck::index_t NumDimSpatial>
std::size_t GetWeightByte(ck::index_t G,
ck::index_t K,
ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths)
std::size_t GetWeightByte(const std::array<ck::index_t, NumDimSpatial>& filter_lengths)
{
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
return sizeof(WeiDataType) * (G * K * C *
std::accumulate(std::begin(filter_spatial_lengths),
std::end(filter_spatial_lengths),
return sizeof(WeiDataType) * (std::accumulate(std::begin(filter_lengths),
std::end(filter_lengths),
static_cast<std::size_t>(1),
std::multiplies<>()));
}
template <typename OutDataType, ck::index_t NumDimSpatial>
std::size_t GetOutputByte(ck::index_t G,
ck::index_t N,
ck::index_t K,
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths)
std::size_t GetOutputByte(const std::array<ck::index_t, NumDimSpatial>& output_lengths)
{
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return sizeof(OutDataType) * (G * N * K *
std::accumulate(std::begin(output_spatial_lengths),
std::end(output_spatial_lengths),
return sizeof(OutDataType) * (std::accumulate(std::begin(output_lengths),
std::end(output_lengths),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>()));
}
......@@ -101,13 +87,12 @@ template <ck::index_t NumDimSpatial,
typename WeiLayout,
typename OutLayout>
bool run_grouped_conv_bwd_weight(
ck::index_t G,
ck::index_t N,
ck::index_t K,
ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial + 3>& input_lengths,
const std::array<ck::index_t, NumDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NumDimSpatial + 3>& filter_lengths,
const std::array<ck::index_t, NumDimSpatial + 3>& weights_strides,
const std::array<ck::index_t, NumDimSpatial + 3>& output_lengths,
const std::array<ck::index_t, NumDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NumDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NumDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NumDimSpatial>& input_left_pads,
......@@ -115,9 +100,9 @@ bool run_grouped_conv_bwd_weight(
{
ck::index_t split_k = 2;
SimpleDeviceMem in(GetInputByte<InDataType, NumDimSpatial>(G, N, C, input_spatial_lengths));
SimpleDeviceMem wei(GetWeightByte<WeiDataType, NumDimSpatial>(G, K, C, filter_spatial_lengths));
SimpleDeviceMem out(GetOutputByte<OutDataType, NumDimSpatial>(G, N, K, output_spatial_lengths));
SimpleDeviceMem in(GetInputByte<InDataType, NumDimSpatial + 3>(input_lengths));
SimpleDeviceMem wei(GetWeightByte<WeiDataType, NumDimSpatial + 3>(filter_lengths));
SimpleDeviceMem out(GetOutputByte<OutDataType, NumDimSpatial + 3>(output_lengths));
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<NumDimSpatial,
InLayout,
......@@ -141,6 +126,10 @@ bool run_grouped_conv_bwd_weight(
float best_gb_per_sec = 0;
float best_tflops = 0;
std::array<ck::index_t, NumDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NumDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NumDimSpatial + 3> b_g_k_c_xs_lengths{};
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
......@@ -150,13 +139,12 @@ bool run_grouped_conv_bwd_weight(
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
out.GetDeviceBuffer(),
G,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_lengths,
input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
......@@ -172,12 +160,10 @@ bool run_grouped_conv_bwd_weight(
{
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop =
GetFlops<NumDimSpatial>(G, N, K, C, output_spatial_lengths, filter_spatial_lengths);
std::size_t num_bytes =
GetInputByte<InDataType, NumDimSpatial>(G, N, C, input_spatial_lengths) +
GetWeightByte<WeiDataType, NumDimSpatial>(G, K, C, filter_spatial_lengths) +
GetOutputByte<OutDataType, NumDimSpatial>(G, N, K, output_spatial_lengths);
std::size_t flop = GetFlops<NumDimSpatial + 3>(output_lengths, filter_lengths);
std::size_t num_bytes = GetInputByte<InDataType, NumDimSpatial + 3>(input_lengths) +
GetWeightByte<WeiDataType, NumDimSpatial + 3>(filter_lengths) +
GetOutputByte<OutDataType, NumDimSpatial + 3>(output_lengths);
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_bytes / 1.E6 / avg_time;
......@@ -217,13 +203,12 @@ bool run_grouped_conv_bwd_weight(
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
out.GetDeviceBuffer(),
G,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_lengths,
input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment