Unverified Commit fbd9d357 authored by Illia Silin, committed by GitHub

Merge pull request #68 from ROCm/merge_from_public

Merge from public
parents 395b155a 22593e25
......@@ -23,7 +23,7 @@ endif()
set(version 1.1.0)
# Check support for CUDA/HIP in Cmake
project(composable_kernel VERSION ${version} LANGUAGES CXX)
project(composable_kernel VERSION ${version} LANGUAGES CXX HIP)
include(CTest)
find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED)
......@@ -112,7 +112,7 @@ message("checking which targets are supported")
#Setting GPU_TARGETS on command line will override this list
if(NOT PROFILER_ONLY)
rocm_check_target_ids(DEFAULT_GPU_TARGETS
TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
TARGETS "gfx900;gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
else()
add_definitions(-DPROFILER_ONLY)
set(GPU_TARGETS "" CACHE STRING "" FORCE)
......@@ -137,12 +137,10 @@ endif()
message("Supported GPU_TARGETS= ${DEFAULT_GPU_TARGETS}")
set(AMDGPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE)
if(GPU_TARGETS)
message("Building CK for the following targets: ${GPU_TARGETS}")
else()
message("Building CK for the following targets: ${AMDGPU_TARGETS}")
message("Building CK for the default targets: ${DEFAULT_GPU_TARGETS}")
endif()
if (GPU_TARGETS)
......@@ -227,7 +225,13 @@ link_libraries(Threads::Threads)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
message("CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}")
## HIP
set(CMAKE_HIP_PLATFORM amd)
set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER})
set(CMAKE_HIP_EXTENSIONS ON)
message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}")
## OpenMP
if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
......
......@@ -4,4 +4,22 @@ if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp)
target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations)
endif()
\ No newline at end of file
if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
add_executable(client_grouped_conv3d_fwd_fp8 grouped_conv3d_fwd_fp8.cpp)
target_link_libraries(client_grouped_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations)
endif()
if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
add_executable(client_grouped_conv3d_fwd_bf8 grouped_conv3d_fwd_bf8.cpp)
target_link_libraries(client_grouped_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations)
endif()
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
add_executable(client_grouped_conv3d_fwd_fp8_bf8 grouped_conv3d_fwd_fp8_bf8.cpp)
target_link_libraries(client_grouped_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations)
add_executable(client_grouped_conv3d_fwd_bf8_fp8 grouped_conv3d_fwd_bf8_fp8.cpp)
target_link_libraries(client_grouped_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations)
endif()
endif()
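# Note: each of the fp8/bf8 clients added above is built only when DTYPES is unset or
# lists the corresponding type; e.g. configuring with -DDTYPES="fp8;bf8" builds all four
# new executables, while -DDTYPES="fp16" skips them (the surrounding gfx9 check still applies).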
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <numeric>
#include <string>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
template <ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetFlops(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& output_lengths,
const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& weights_lengths)
{
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
ck::index_t G = weights_lengths[0];
ck::index_t N = output_lengths[1];
ck::index_t K = weights_lengths[1];
ck::index_t C = weights_lengths[2];
return static_cast<std::size_t>(2) * G * N * K * C *
std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim),
std::end(output_lengths),
static_cast<std::size_t>(1),
std::multiplies<>()) *
std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim),
std::end(weights_lengths),
static_cast<std::size_t>(1),
std::multiplies<>());
}
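// Example: for a 3D grouped conv with G=2, N=4, K=8, C=16, an 8x8x8 output and a 3x3x3
// filter, this evaluates to 2 * 2 * 4 * 8 * 16 * (8*8*8) * (3*3*3) = 28,311,552 FLOP.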
template <typename InDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetInputByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& input_lengths)
{
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
return sizeof(InDataType) * std::accumulate(std::begin(input_lengths),
std::end(input_lengths),
static_cast<std::size_t>(1),
std::multiplies<>());
}
template <typename WeiDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetWeightByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& weights_lengths)
{
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths),
std::end(weights_lengths),
static_cast<std::size_t>(1),
std::multiplies<>());
}
template <typename OutDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetOutputByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& output_lengths)
{
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths),
std::end(output_lengths),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>());
}
template <ck::index_t NumDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout,
ck::index_t NumNonSpatialDim = 3,
typename AComputeType = InDataType,
typename BComputeType = AComputeType>
bool run_grouped_conv_fwd(std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> in_lengths,
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> wei_lengths,
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> out_lengths)
{
std::size_t in_mem_size = GetInputByte<InDataType, NumDimSpatial>(in_lengths);
std::size_t wei_mem_size = GetWeightByte<WeiDataType, NumDimSpatial>(wei_lengths);
std::size_t out_mem_size = GetOutputByte<OutDataType, NumDimSpatial>(out_lengths);
SimpleDeviceMem in(in_mem_size);
SimpleDeviceMem wei(wei_mem_size);
SimpleDeviceMem out(out_mem_size);
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> in_strides;
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> wei_strides;
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> out_strides;
in_strides.fill(0);
wei_strides.fill(0);
out_strides.fill(0);
in_strides.back() = 1;
wei_strides.back() = 1;
out_strides.back() = 1;
std::partial_sum(rbegin(in_lengths),
std::prev(rend(in_lengths)),
std::next(rbegin(in_strides)),
std::multiplies<>{});
std::partial_sum(rbegin(wei_lengths),
std::prev(rend(wei_lengths)),
std::next(rbegin(wei_strides)),
std::multiplies<>{});
std::partial_sum(rbegin(out_lengths),
std::prev(rend(out_lengths)),
std::next(rbegin(out_strides)),
std::multiplies<>{});
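// At this point the strides describe fully packed tensors: each stride is the product of
// all lengths to its right. E.g. for 2D NHWGC in_lengths {N, Hi, Wi, G, C} the computed
// in_strides are {Hi*Wi*G*C, Wi*G*C, G*C, C, 1}.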
// transpose NDHWGC/GKZYXC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW
std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths));
std::rotate(rbegin(in_lengths),
std::next(rbegin(in_lengths)),
std::next(rbegin(in_lengths), NumDimSpatial + 1));
std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides));
std::rotate(rbegin(in_strides),
std::next(rbegin(in_strides)),
std::next(rbegin(in_strides), NumDimSpatial + 1));
std::rotate(rbegin(wei_lengths),
std::next(rbegin(wei_lengths)),
std::next(rbegin(wei_lengths), NumDimSpatial + 1));
std::rotate(rbegin(wei_strides),
std::next(rbegin(wei_strides)),
std::next(rbegin(wei_strides), NumDimSpatial + 1));
std::rotate(
std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths));
std::rotate(rbegin(out_lengths),
std::next(rbegin(out_lengths)),
std::next(rbegin(out_lengths), NumDimSpatial + 1));
std::rotate(
std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides));
std::rotate(rbegin(out_strides),
std::next(rbegin(out_strides)),
std::next(rbegin(out_strides), NumDimSpatial + 1));
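// Example of the rotations for 3D: NDHWGC in_lengths {N, D, H, W, G, C} first become
// GNDHWC {G, N, D, H, W, C} and then GNCDHW {G, N, C, D, H, W}; the strides are permuted
// the same way, so only the tensor description changes, never the data itself.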
std::array<ck::index_t, NumDimSpatial> conv_filter_strides;
std::array<ck::index_t, NumDimSpatial> conv_filter_dilations;
std::array<ck::index_t, NumDimSpatial> input_left_pads;
std::array<ck::index_t, NumDimSpatial> input_right_pads;
conv_filter_strides.fill(1);
conv_filter_dilations.fill(1);
input_left_pads.fill(1);
input_right_pads.fill(1);
std::size_t flop = GetFlops<NumDimSpatial>(out_lengths, wei_lengths);
std::size_t num_bytes = in_mem_size + wei_mem_size + out_mem_size;
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<>,
OutDataType,
PassThrough,
PassThrough,
PassThrough,
AComputeType,
BComputeType>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
int best_op_id = -1;
float best_avg_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
float best_tflops = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
std::array<const void*, 0>{},
out.GetDeviceBuffer(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{{}},
std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{{}},
out_lengths,
out_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_id = i;
best_op_name = op_name;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
best_tflops = tflops;
}
}
else
{
std::cerr << op_name << " does not support this problem" << std::endl;
}
}
if(best_op_id < 0)
{
std::cerr << "no suitable instance" << std::endl;
return false;
}
std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
<< " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
// run the best instance
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(
in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
std::array<const void*, 0>{},
out.GetDeviceBuffer(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{{}},
std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{{}},
out_lengths,
out_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return true;
}
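// A minimal sketch (not part of this diff) of how one of the new fp8 clients could drive
// this helper; it assumes ck::f8_t and the NDHWGC/GKZYXC/NDHWGK layout tags, mirroring
// the fp16 clients updated elsewhere in this change:
//
//   using InLayout  = ck::tensor_layout::convolution::NDHWGC;
//   using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
//   using OutLayout = ck::tensor_layout::convolution::NDHWGK;
//
//   int main()
//   {
//       // lengths: in {N, Di, Hi, Wi, G, C}, wei {G, K, Z, Y, X, C}, out {N, Do, Ho, Wo, G, K};
//       // with the 3x3x3 filter, unit strides/dilations and pad 1 set above, the output
//       // spatial sizes equal the input spatial sizes.
//       return run_grouped_conv_fwd<3, ck::f8_t, ck::f8_t, ck::f8_t,
//                                   InLayout, WeiLayout, OutLayout>(
//                  {16, 16, 16, 16, 2, 64},  // in
//                  {2, 64, 3, 3, 3, 64},     // wei
//                  {16, 16, 16, 16, 2, 64})  // out
//                  ? EXIT_SUCCESS
//                  : EXIT_FAILURE;
//   }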
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <numeric>
#include <vector>
#include "common.hpp"
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using InDataType = ck::half_t;
using WeiDataType = ck::half_t;
......@@ -31,199 +24,16 @@ static constexpr ck::index_t X = 3;
static constexpr ck::index_t Wi = 28;
static constexpr ck::index_t Wo = 28;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main()
{
std::array<ck::index_t, NumDimSpatial + 3> in_lengths{G, N, Wi, C};
std::array<ck::index_t, NumDimSpatial + 3> in_strides{0, 0, 0, 1};
std::array<ck::index_t, NumDimSpatial + 3> wei_lengths{G, K, X, C};
std::array<ck::index_t, NumDimSpatial + 3> wei_strides{0, 0, 0, 1};
std::array<ck::index_t, NumDimSpatial + 3> out_lengths{G, N, Wo, K};
std::array<ck::index_t, NumDimSpatial + 3> out_strides{0, 0, 0, 1};
std::partial_sum(rbegin(in_lengths),
std::prev(rend(in_lengths)),
std::next(rbegin(in_strides)),
std::multiplies<>{});
std::partial_sum(rbegin(wei_lengths),
std::prev(rend(wei_lengths)),
std::next(rbegin(wei_strides)),
std::multiplies<>{});
std::partial_sum(rbegin(out_lengths),
std::prev(rend(out_lengths)),
std::next(rbegin(out_strides)),
std::multiplies<>{});
// transpose GNWC/GKXC/GNWK to GNCW/GKCX/GNKW
std::rotate(rbegin(in_lengths),
std::next(rbegin(in_lengths)),
std::next(rbegin(in_lengths), NumDimSpatial + 1));
std::rotate(rbegin(in_strides),
std::next(rbegin(in_strides)),
std::next(rbegin(in_strides), NumDimSpatial + 1));
std::rotate(rbegin(wei_lengths),
std::next(rbegin(wei_lengths)),
std::next(rbegin(wei_lengths), NumDimSpatial + 1));
std::rotate(rbegin(wei_strides),
std::next(rbegin(wei_strides)),
std::next(rbegin(wei_strides), NumDimSpatial + 1));
std::rotate(rbegin(out_lengths),
std::next(rbegin(out_lengths)),
std::next(rbegin(out_lengths), NumDimSpatial + 1));
std::rotate(rbegin(out_strides),
std::next(rbegin(out_strides)),
std::next(rbegin(out_strides), NumDimSpatial + 1));
std::array<ck::index_t, NumDimSpatial> filter_strides{1};
std::array<ck::index_t, NumDimSpatial> filter_dilations{1};
std::array<ck::index_t, NumDimSpatial> input_left_pads{1};
std::array<ck::index_t, NumDimSpatial> input_right_pads{1};
SimpleDeviceMem in(sizeof(InDataType) * G * N * Wi * C);
SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * X * C);
SimpleDeviceMem out(sizeof(OutDataType) * G * N * Wo * K);
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<>,
OutDataType,
PassThrough,
PassThrough,
PassThrough>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
int best_op_id = -1;
float best_avg_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
float best_tflops = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
{},
out.GetDeviceBuffer(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
{},
{},
out_lengths,
out_strides,
filter_strides,
filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = std::size_t(2) * G * N * K * C * Wo * X;
std::size_t num_bytes = sizeof(InDataType) * G * N * Wi * C +
sizeof(WeiDataType) * G * K * X * C +
sizeof(OutDataType) * G * N * Wo * K;
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_id = i;
best_op_name = op_name;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
best_tflops = tflops;
}
}
else
{
std::cerr << op_name << " does not support this problem" << std::endl;
}
}
if(best_op_id < 0)
{
std::cerr << "no suitable instance" << std::endl;
return EXIT_FAILURE;
}
std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
<< " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
// run the best instance
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
{},
out.GetDeviceBuffer(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
{},
{},
out_lengths,
out_strides,
filter_strides,
filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return run_grouped_conv_fwd<NumDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout,
3>({N, Wi, G, C}, {G, K, X, C}, {N, Wo, G, K})
? EXIT_SUCCESS
: EXIT_FAILURE;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <numeric>
#include <vector>
#include "common.hpp"
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using InDataType = ck::half_t;
using WeiDataType = ck::half_t;
......@@ -34,167 +27,16 @@ static constexpr ck::index_t Wi = 28; // input W
static constexpr ck::index_t Ho = 28; // output H
static constexpr ck::index_t Wo = 28; // output W
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main()
{
// We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space.
// However, CK's API only accepts lengths and strides in the order GNCHW/GKCYX/GNKHW.
// Hence, we need to adjust the order of the strides.
std::array<ck::index_t, 5> in_lengths{G, N, C, Hi, Wi};
std::array<ck::index_t, 5> in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C};
std::array<ck::index_t, 5> wei_lengths{G, K, C, Y, X};
std::array<ck::index_t, 5> wei_strides{K * Y * X * C, Y * X * C, 1, X * C, C};
std::array<ck::index_t, 5> out_lengths{G, N, K, Ho, Wo};
std::array<ck::index_t, 5> out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C};
std::array<ck::index_t, NumDimSpatial> filter_strides{1, 1};
std::array<ck::index_t, NumDimSpatial> filter_dilations{1, 1};
std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1};
std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1};
SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C);
SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C);
SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K);
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<>,
OutDataType,
PassThrough,
PassThrough,
PassThrough>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
int best_op_id = -1;
float best_avg_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
float best_tflops = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
{},
out.GetDeviceBuffer(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
{},
{},
out_lengths,
out_strides,
filter_strides,
filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X;
std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C +
sizeof(WeiDataType) * G * K * Y * X * C +
sizeof(OutDataType) * N * Ho * Wo * G * K;
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_id = i;
best_op_name = op_name;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
best_tflops = tflops;
}
}
else
{
std::cerr << op_name << " does not support this problem" << std::endl;
}
}
if(best_op_id < 0)
{
std::cerr << "no suitable instance" << std::endl;
return EXIT_FAILURE;
}
std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
<< " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
// run the best instance
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
{},
out.GetDeviceBuffer(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
{},
{},
out_lengths,
out_strides,
filter_strides,
filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return run_grouped_conv_fwd<NumDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout,
3>({N, Hi, Wi, G, C}, {G, K, Y, X, C}, {N, Ho, Wo, G, K})
? EXIT_SUCCESS
: EXIT_FAILURE;
}
......@@ -7,22 +7,6 @@ endif()
if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp)
target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_conv_operations)
add_executable(client_conv3d_fwd_fp8 conv3d_fwd_fp8.cpp)
target_link_libraries(client_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations)
endif()
if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
add_executable(client_conv3d_fwd_bf8 conv3d_fwd_bf8.cpp)
target_link_libraries(client_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations)
endif()
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
add_executable(client_conv3d_fwd_fp8_bf8 conv3d_fwd_fp8_bf8.cpp)
target_link_libraries(client_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations)
add_executable(client_conv3d_fwd_bf8_fp8 conv3d_fwd_bf8_fp8.cpp)
target_link_libraries(client_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations)
endif()
if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES)
......
rocm-docs-core==1.1.2
rocm-docs-core==1.1.3
sphinxcontrib-bibtex==2.6.2
......@@ -103,7 +103,7 @@ requests==2.31.0
# via
# pygithub
# sphinx
rocm-docs-core==1.1.2
rocm-docs-core==1.1.3
# via -r requirements.in
six==1.16.0
# via
......
......@@ -44,6 +44,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
endif()
endforeach()
endif()
if(INSTANCES_ONLY)
set(EX_TARGETS ${DEFAULT_GPU_TARGETS})
else()
set(EX_TARGETS ${GPU_TARGETS})
endif()
#Do not build any DL examples if DL_KERNELS not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
......@@ -53,23 +60,30 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
endforeach()
#Do not build any XDL examples if gfx9 targets are not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
message("removing xdl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any WMMA examples if gfx11 targets are not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
message("removing wmma example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(FILE_NAME)
if(FILE_NAME MATCHES "_xdl")
list(REMOVE_ITEM EX_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
elseif(FILE_NAME MATCHES "_wmma")
list(REMOVE_ITEM EX_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
endif()
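# e.g. with GPU_TARGETS="gfx90a;gfx1100", an _xdl example keeps only gfx90a and a _wmma
# example keeps only gfx1100 in the EX_TARGETS list used for HIP_ARCHITECTURES below.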
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} )
add_dependencies(examples ${EXAMPLE_NAME})
add_dependencies(check ${EXAMPLE_NAME})
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
......@@ -118,6 +132,12 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
endif()
endforeach()
endif()
if(INSTANCES_ONLY)
set(EX_TARGETS ${DEFAULT_GPU_TARGETS})
else()
set(EX_TARGETS ${GPU_TARGETS})
endif()
#Do not build any DL examples if DL_KERNELS not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
......@@ -127,23 +147,30 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
endforeach()
#Do not build any XDL examples if gfx9 targets are not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
message("removing xdl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any WMMA examples if gfx11 targets are not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
message("removing wmma example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(FILE_NAME)
if(FILE_NAME MATCHES "_xdl")
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
elseif(FILE_NAME MATCHES "_wmma")
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
endif()
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
add_dependencies(examples ${EXAMPLE_NAME})
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} )
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
set(result 0)
endif()
......
......@@ -104,14 +104,19 @@ inline void flush_icache()
hip_check_error(hipGetLastError());
}
// if TimePreprocess == false, the returned time does not include the preprocess time
template <bool TimePreprocess, typename Args, typename F, typename PreProcessFunc>
template <bool TimePreprocess,
typename GemmArgs,
typename... Args,
typename F,
typename PreProcessFunc>
float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
PreProcessFunc preprocess,
F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
Args& args)
GemmArgs& gemm_args,
Args... args)
{
#if CK_TIME_KERNEL
#define MEDIAN 1
......@@ -133,7 +138,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
// warm up
for(int i = 0; i < stream_config.cold_niters_; ++i)
{
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
hip_check_error(hipGetLastError());
}
......@@ -172,7 +177,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
preprocess();
}
// run real kernel
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
hip_check_error(hipGetLastError());
// end real kernel
......@@ -190,9 +195,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
{
std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
printf("args.p_a_grid: %p, args.p_b_grid:%p\n",
static_cast<const void*>(args.p_a_grid),
static_cast<const void*>(args.p_b_grid));
printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
static_cast<const void*>(gemm_args.p_a_grid),
static_cast<const void*>(gemm_args.p_b_grid));
}
}
......@@ -216,13 +221,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
else
{
preprocess();
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
hip_check_error(hipGetLastError());
return 0;
}
#else
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
hip_check_error(hipGetLastError());
return 0;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -1952,7 +1952,7 @@ struct Modulo
}
};
template <typename LowLengths>
template <typename LowLengths, bool ApplyModulo>
struct Xor
{
using LowerIndex = MultiIndex<2>;
......@@ -1981,8 +1981,15 @@ struct Xor
idx_low(Number<0>{}) = idx_up[Number<0>{}];
idx_low(Number<1>{}) =
idx_up[Number<1>{}] ^ (idx_up[Number<0>{}] % up_lengths_[Number<1>{}]);
if constexpr(ApplyModulo)
{
idx_low(Number<1>{}) =
idx_up[Number<1>{}] ^ (idx_up[Number<0>{}] % up_lengths_[Number<1>{}]);
}
else
{
idx_low(Number<1>{}) = idx_up[Number<1>{}] ^ idx_up[Number<0>{}];
}
}
template <typename LowIdxDiff,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -128,9 +128,15 @@ __host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus,
return Modulo<Modulus, UpLength>{modulus, up_length};
}
template <typename LowLengths>
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths& low_lengths)
{
return Xor<LowLengths, true /*ApplyModulo*/>{low_lengths};
}
template <typename LowLengths>
__host__ __device__ constexpr auto make_xor_transform(const LowLengths& low_lengths)
{
return Xor<LowLengths>{low_lengths};
return Xor<LowLengths, false /*ApplyModulo*/>{low_lengths};
}
} // namespace ck
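// The two factories above differ only in whether the first upper index is reduced modulo
// up_lengths_[1] before being XOR-ed in. For example, with up_lengths_ = {8, 4} and
// idx_up = {5, 3}:
//   make_xor_with_modulo_transform: idx_low[1] = 3 ^ (5 % 4) = 2
//   make_xor_transform:             idx_low[1] = 3 ^ 5       = 6
// and idx_low[0] = idx_up[0] = 5 in both cases.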
......@@ -53,8 +53,7 @@ __global__ void
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
......
......@@ -14,95 +14,137 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename AGridDesc_B_K0_M_K1,
typename BGridDesc_B_K0_N_K1,
typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2CTileMap,
typename ComputePtrOffsetOfBatch,
bool HasMainKBlockLoop>
index_t NumBatchToMerge,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_batched_gemm_xdlops_bwd_weight(
const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const index_t batch_count,
const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3(
typename GridwiseGemm::Argument karg,
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const index_t num_k_per_block)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_b_k0_m_k1_grid_desc;
ignore = b_b_k0_n_k1_grid_desc;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = batch_count;
ignore = block_2_ctile_map;
ignore = compute_ptr_offset_of_batch;
compute_ptr_offset_of_batch.GetAPtrOffset(0);
compute_ptr_offset_of_batch.GetBPtrOffset(0);
compute_ptr_offset_of_batch.GetCPtrOffset(0);
ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename GridwiseGemm,
typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename ComputePtrOffsetOfBatch,
index_t NumBatchToMerge,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds(
typename GridwiseGemm::Argument karg,
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const index_t num_k_per_block)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
// offset base pointer for each work-group
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
// Passing two lds pointers is the key to telling the compiler that ds_read/write
// can operate on different lds chunks at the same time without an ordering dependency
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
p_shared_0,
p_shared_1,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx);
#else
ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......@@ -121,7 +163,7 @@ template <ck::index_t NDimSpatial,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t KPerBlock,
ck::index_t K1,
ck::index_t MPerXdl,
ck::index_t NPerXdl,
......@@ -145,8 +187,11 @@ template <ck::index_t NDimSpatial,
index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
typename ComputeTypeA = InDataType,
typename ComputeTypeB = ComputeTypeA>
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
index_t NumBatchToMerge = 1,
typename ComputeTypeA = InDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
: public DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout,
......@@ -161,6 +206,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
ComputeTypeA,
ComputeTypeB>
{
static_assert(is_same_v<InElementwiseOperation, element_wise::PassThrough>);
static_assert(is_same_v<WeiElementwiseOperation, element_wise::PassThrough>);
static_assert(is_same_v<OutElementwiseOperation, element_wise::PassThrough>);
using DeviceOp = DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle;
using ADataType = OutDataType;
......@@ -183,101 +232,123 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
static constexpr auto K1Number = Number<K1>{};
static constexpr auto conv_to_gemm_transformer =
static constexpr auto conv_to_gemm_transformer_v2 =
TransformConvBwdWeightToGemmV2<NDimSpatial,
MPerBlock,
NPerBlock,
K1Number,
KPerBlock / K1Number,
NumBatchToMerge,
ConvBackwardWeightSpecialization>{};
static constexpr auto conv_to_gemm_transformer_v1 =
TransformConvBwdWeightToGemm<NDimSpatial,
MPerBlock,
NPerBlock,
K1Number,
K0PerBlock,
KPerBlock / K1Number,
ConvBackwardWeightSpecialization>{};
// Bytes per 32 lds bank: 32 * 4 bytes
static constexpr auto BankLength = 128;
static constexpr auto ElePerBank = BankLength / sizeof(ADataType);
// M1 & M0
static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1;
static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock;
static constexpr auto ABlockLdsM1Padding = 4;
static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default;
// N1 & N0
static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1;
static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock;
static constexpr auto BBlockLdsN1Padding = 4;
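// e.g. for a 2-byte ADataType (half_t): ElePerBank = 128 / 2 = 64, and with K1 = 8 this
// yields ABlockLdsM1PerBlock = BBlockLdsN1PerBlock = 64 / 8 = 8.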
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc()
{
const ck::index_t dim = 1;
const ck::index_t batch = 1;
const std::array<ck::index_t, NDimSpatial> lengths{1, 1};
const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1};
const std::array<ck::index_t, NDimSpatial> params{1, 1};
return conv_to_gemm_transformer_v2
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim,
dim,
dim,
lengths,
lengths,
lengths,
strides,
strides,
strides,
params,
params,
params,
params,
batch);
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc()
{
const ck::index_t dim = 1;
const ck::index_t batch = 1;
const std::array<ck::index_t, NDimSpatial> lengths{1};
const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1};
const std::array<ck::index_t, NDimSpatial> params{1};
return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
dim,
dim,
dim,
lengths,
lengths,
lengths,
strides,
strides,
strides,
params,
params,
params,
params,
batch);
const std::array<ck::index_t, NDimSpatial> lengths{1, 1, 1};
const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1, 1};
const std::array<ck::index_t, NDimSpatial> params{1, 1, 1};
return conv_to_gemm_transformer_v2
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim,
dim,
dim,
lengths,
lengths,
lengths,
strides,
strides,
strides,
params,
params,
params,
params,
batch);
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc()
static auto GetElementwiseCGridDesc()
{
const ck::index_t dim = 1;
const ck::index_t batch = 1;
const std::array<ck::index_t, NDimSpatial> lengths{1, 1};
const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1};
const std::array<ck::index_t, NDimSpatial> params{1, 1};
return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
dim,
dim,
dim,
lengths,
lengths,
lengths,
strides,
strides,
strides,
params,
params,
params,
params,
batch);
return conv_to_gemm_transformer_v1
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim,
dim,
dim,
lengths,
lengths,
lengths,
strides,
strides,
strides,
params,
params,
params,
params,
batch)[I2];
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc()
static auto GetElementwiseCGridDesc()
{
const ck::index_t dim = 1;
const ck::index_t batch = 1;
const std::array<ck::index_t, NDimSpatial> lengths{1, 1, 1};
const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1, 1};
const std::array<ck::index_t, NDimSpatial> params{1, 1, 1};
return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(
dim,
dim,
dim,
lengths,
lengths,
lengths,
strides,
strides,
strides,
params,
params,
params,
params,
batch);
return conv_to_gemm_transformer_v1
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim,
dim,
dim,
lengths,
lengths,
lengths,
strides,
strides,
strides,
params,
params,
params,
params,
batch)[I2];
}
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
......@@ -285,60 +356,56 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
BlockSize,
ADataType,
BDataType,
AccDataType,
AccDataType,
InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
element_wise::PassThrough,
MPerBlock,
NPerBlock,
K0PerBlock,
MPerXdl,
NPerXdl,
K1,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
ABlockLdsM1PerBlock,
ABlockLdsM0PerBlock,
ABlockLdsM1Padding,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
BBlockLdsN1PerBlock,
BBlockLdsN0PerBlock,
BBlockLdsN1Padding,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXdl,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
true,
true,
1,
PipelineVersion::v1,
ComputeTypeA,
ComputeTypeB>;
using CElementwiseGridDesc_M_N =
remove_cvref_t<decltype(GetElementwiseCGridDesc<NDimSpatial>())>;
using GridwiseGemm =
GridwiseGemm_xdl_cshuffle_v3<tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor,
tensor_layout::gemm::RowMajor,
ADataType,
BDataType,
AccDataType,
AccDataType,
AccDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
K1,
K1,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false,
BBlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CBlockTransferScalarPerVector_NWaveNPerXdl,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
static constexpr index_t ClusterLengthMPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
......@@ -347,8 +414,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
using GridwiseElementwise =
GridwiseElementwise<Tuple<CGridDesc_M_N>,
Tuple<CGridDesc_M_N>,
GridwiseElementwise<Tuple<CElementwiseGridDesc_M_N>,
Tuple<CElementwiseGridDesc_M_N>,
Tuple<const AccDataType*>,
Tuple<EDataType*>,
Block2TileMapElementwise,
......@@ -366,10 +433,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
// Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
using Block2CTileMap =
decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}, 1, 1));
struct Argument : public BaseArgument
{
......@@ -395,11 +460,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
: p_a_grid_{p_out_grid},
p_b_grid_{p_in_grid},
p_e_grid_{p_wei_grid},
a_grid_desc_kbatch_k0_m_k1_{},
b_grid_desc_kbatch_k0_n_k1_{},
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
ce_grid_desc_m_n_{},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{},
compute_ptr_offset_of_batch_{},
M01_{M01},
N01_{N01},
......@@ -430,7 +494,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
begin(output_spatial_lengths_));
const auto descs =
conv_to_gemm_transformer
conv_to_gemm_transformer_v2
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
......@@ -447,15 +511,34 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
input_right_pads,
k_batch_);
a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
ce_grid_desc_m_n_ = descs[I2];
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
ce_grid_desc_m_n_ = descs[I2];
ce_elementwise_grid_desc_m_n_ =
conv_to_gemm_transformer_v1
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides,
e_g_k_c_xs_strides,
a_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
k_batch_)[I2];
block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptor(ce_grid_desc_m_n_, M01, N01, k_batch_);
elementwise_block_2_ctile_map_ = Block2TileMapElementwise{
ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)};
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0];
......@@ -465,16 +548,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
end(filter_spatial_lengths_),
index_t{1},
std::multiplies<>{});
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
b_grid_desc_kbatch_k0_n_k1_,
ce_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(
ce_grid_desc_m_n_);
}
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ce_grid_desc_m_n_,
GridwiseGemm::CalculateMBlock(GemmM),
GridwiseGemm::CalculateNBlock(GemmN));
}
std::size_t GetWorkspaceSizeBytes() const
......@@ -486,12 +564,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
const BDataType* p_b_grid_;
EDataType* p_e_grid_;
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N ce_grid_desc_m_n_;
CElementwiseGridDesc_M_N ce_elementwise_grid_desc_m_n_;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
Block2CTileMap block_2_ctile_map_;
Block2TileMapElementwise elementwise_block_2_ctile_map_;
// for computing batch offset
......@@ -525,96 +603,676 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
void ShowInfo(const Argument& arg)
{
std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.ce_grid_desc_m_n_{" << arg.ce_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.ce_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
float RunGemmV3(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.ce_grid_desc_m_n_,
arg.block_2_ctile_map_))
const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
const index_t GemmK =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
AccDataType* p_c_grid = type_convert<AccDataType*>(arg.p_workspace_);
// nullptr for output, will be set after workspace set
typename GridwiseGemm::Argument gemm_arg{arg.p_a_grid_,
arg.p_b_grid_,
p_c_grid,
GemmM,
GemmN,
GemmK,
I0,
I0,
I0,
arg.k_batch_};
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumBatchToMerge);
float ave_time = 0;
index_t k_grain = gemm_arg.KBatch * KPerBlock;
index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * (KPerBlock);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto num_k_per_block =
arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
const auto clear_workspace = [&]() {
hip_check_error(hipMemsetAsync(
gemm_arg.p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_));
};
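// Zeroing the workspace before every launch matters for the split-K path (KBatch > 1),
// where the kernel accumulates partial results into the AccDataType workspace with
// atomic adds.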
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
typename GridwiseGemm::Argument gemm_arg_ = gemm_arg;
ck::utility::RotatingMemWrapper<typename GridwiseGemm::Argument> rotating_mem(
gemm_arg_,
stream_config.rotating_count,
gemm_arg_.M * gemm_arg_.K * sizeof(ADataType),
gemm_arg_.K * gemm_arg_.N * sizeof(BDataType));
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
clear_workspace();
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
gemm_arg_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
}
else
{
ave_time = launch_and_time_kernel_with_preprocess(
stream_config,
clear_workspace,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
gemm_arg,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
}
};
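// minimum_occupancy is presumably forwarded as the kernel's launch-bounds occupancy hint:
// one resident wave for the Intrawave scheduler, two for Interwave.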
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
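// The branches below pick a compile-time kernel instantiation: AtomicAdd merges split-K partial
// results when KBatch > 1, Set writes directly otherwise, combined with the pipeline version's
// tail-number handling.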
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(gemm_arg.KBatch > 1)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(gemm_arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
}
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(gemm_arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(gemm_arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
}
else
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(gemm_arg.KBatch > 1)
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
NumBatchToMerge,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
}
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
auto launch_gemm_kernel = [&](auto has_main_k_block_loop) {
AccDataType* p_c_grid = type_convert<AccDataType*>(arg.p_workspace_);
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_;
constexpr bool has_main_loop = has_main_k_block_loop.value;
auto preprocess = [&]() {
hip_check_error(hipMemsetAsync(
p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_));
};
const auto kernel = kernel_batched_gemm_xdlops_bwd_weight<
GridwiseGemm,
ADataType,
BDataType,
AccDataType,
OutElementwiseOperation,
InElementwiseOperation,
element_wise::PassThrough,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<DeviceOp::Block2CTileMap>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
has_main_loop>;
return launch_and_time_kernel_with_preprocess(
stream_config,
preprocess,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
p_c_grid,
arg.a_element_op_,
arg.b_element_op_,
element_wise::PassThrough{},
arg.Conv_G_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_,
arg.compute_ptr_offset_of_batch_);
};
return ave_time;
}
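// Two-stage flow: RunGemmV3 accumulates the weight gradient into the AccDataType workspace, then
// Run launches a batched elementwise kernel that converts it into the final output tensor.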
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
auto launch_elementwise_kernel = [&]() {
const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_);
const index_t grid_size =
arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) *
arg.Conv_G_;
const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
arg.ce_elementwise_grid_desc_m_n_) *
arg.Conv_G_;
std::array<index_t, I1> in_out_batch_strides = {
arg.compute_ptr_offset_of_batch_.BatchStrideC_};
const auto kernel = kernel_batched_elementwise<GridwiseElementwise,
ck::Tuple<CGridDesc_M_N>,
ck::Tuple<CGridDesc_M_N>,
ck::Tuple<CElementwiseGridDesc_M_N>,
ck::Tuple<CElementwiseGridDesc_M_N>,
ck::Tuple<const AccDataType*>,
ck::Tuple<EDataType*>,
Block2TileMapElementwise,
......@@ -627,8 +1285,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
dim3(grid_size),
dim3(BlockSize),
0,
make_tuple(arg.ce_grid_desc_m_n_),
make_tuple(arg.ce_grid_desc_m_n_),
make_tuple(arg.ce_elementwise_grid_desc_m_n_),
make_tuple(arg.ce_elementwise_grid_desc_m_n_),
make_tuple(p_c_grid),
make_tuple(arg.p_e_grid_),
arg.elementwise_block_2_ctile_map_,
......@@ -638,16 +1296,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
in_out_batch_strides);
};
float avg_time = 0;
if(has_main_k0_block_loop)
{
avg_time = launch_gemm_kernel(integral_constant<bool, true>{});
}
else
{
avg_time = launch_gemm_kernel(integral_constant<bool, false>{});
}
float avg_time = RunGemmV3(arg, stream_config);
avg_time += launch_elementwise_kernel();
return avg_time;
}
......@@ -667,6 +1316,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
const index_t GemmK =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
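// Build a throwaway GEMM argument (null pointers, zero strides) just to query derived quantities such as AK0.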
typename GridwiseGemm::Argument gemm_arg{
nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_};
const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / K1);
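// Pipelines other than v1 prefetch several K tiles, so the K loop must have more iterations than the prefetch depth.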
if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
{
if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages)
{
return false;
}
}
// Check the workspace here; this lets other instances from the factory be used
// even when the workspace has not been allocated.
if(!arg.p_workspace_)
......@@ -723,10 +1389,38 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
}
}
if constexpr(NumBatchToMerge > 1)
{
// supported only if the whole M and N extents can be processed by a single block
if(!(GemmM <= MPerBlock && GemmN <= NPerBlock))
{
return false;
}
if(!(arg.Conv_C_ == 1 && arg.Conv_K_ == 1))
{
return false;
}
if(arg.Conv_G_ % NumBatchToMerge != 0)
{
return false;
}
}
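// Vectorized global loads need C and K divisible by the per-vector scalar counts; if not, only the
// degenerate case with extent 1 and unit batch stride is accepted.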
if(!(arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0 &&
arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0))
{
if(!(arg.Conv_K_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideA_ == 1))
{
return false;
}
if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1))
{
return false;
}
}
// vectorized loads of the A/B matrices from global memory
if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1))
{
return false;
}
......@@ -737,11 +1431,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
return false;
}
// Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.ce_grid_desc_m_n_,
arg.block_2_ctile_map_);
return true;
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
......@@ -840,13 +1530,24 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< KPerBlock << ", "
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
<< K1 << ", "
<< MXdlPerWave << ", "
......@@ -857,7 +1558,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl
<< CBlockTransferScalarPerVector_NWaveNPerXdl << ", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< NumBatchToMerge
<< ">";
// clang-format on
......
......@@ -45,8 +45,7 @@ __global__ void
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t KBatch = 1;
......
......@@ -79,7 +79,7 @@ __global__ void
ignore = b_element_op;
ignore = c_element_op;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx1100__))
#endif // end of if (defined(__gfx11__))
}
// Assume B is Col-Major
......