Commit 5dee1b64 authored by Ville Pietilä's avatar Ville Pietilä
Browse files

Merge remote-tracking branch 'origin/develop' into vpietila/ggemm-profiling

parents 870c3a76 355893cd
[Back to the main page](../README.md)
# Composable Kernel examples
\ No newline at end of file
...@@ -92,6 +92,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -92,6 +92,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0) if(s.log_level_ > 0)
{ {
std::cout << "Launching kernel with args:" std::cout << "Launching kernel with args:"
......
...@@ -119,6 +119,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -119,6 +119,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0) if(s.log_level_ > 0)
{ {
std::cout << "Launching kernel with args:" std::cout << "Launching kernel with args:"
......
add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp)
# Grouped CShuffle GEMM
This folder contains example for Grouped GEMM using ck_tile tile-programming implementation. Currently, it only supports the basic feature of the CK Tile GEMM, but creates the placeholders for the future support on different GEMM pipeline and different GEMM modules. In the near future, we will gradually migrate all the GEMM features from old CK to CK Tile.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the gemm calculation
make tile_example_grouped_gemm -j
```
This will result in an executable `build/bin/tile_example_grouped_gemm`
## example
```
args:
-a_layout Tensor A layout (default:R)
-b_layout Tensor B layout (default:R)
-c_layout Tensor C layout (default:R)
-v 0. No validation, 1. Validation on CPU
-warmup number of iterations before benchmark the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include <memory>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "grouped_gemm.hpp"
#include "utils.hpp"
namespace {
struct GroupedGemmKernelParam
{
static const bool kPadM = false;
static const bool kPadN = false;
static const bool kPadK = false;
static const bool kTilePermute = false;
static const ck_tile::index_t kOutputRank = 2;
static const int kBlockPerCu = 1;
static const ck_tile::index_t M_Tile = 128;
static const ck_tile::index_t N_Tile = 128;
static const ck_tile::index_t K_Tile = 32;
static const ck_tile::index_t M_Warp = 2;
static const ck_tile::index_t N_Warp = 2;
static const ck_tile::index_t K_Warp = 1;
static const ck_tile::index_t M_Warp_Tile = 32;
static const ck_tile::index_t N_Warp_Tile = 32;
static const ck_tile::index_t K_Warp_Tile = 8;
};
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemmKernelParam::M_Tile,
GroupedGemmKernelParam::N_Tile,
GroupedGemmKernelParam::K_Tile>,
ck_tile::sequence<GroupedGemmKernelParam::M_Warp,
GroupedGemmKernelParam::N_Warp,
GroupedGemmKernelParam::K_Warp>,
ck_tile::sequence<GroupedGemmKernelParam::M_Warp_Tile,
GroupedGemmKernelParam::N_Warp_Tile,
GroupedGemmKernelParam::K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
template <typename CLayout>
using GemmEpilogue = std::conditional_t<
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
GroupedGemmKernelParam::kPadM,
GroupedGemmKernelParam::kPadN,
GroupedGemmKernelParam::kTilePermute,
GroupedGemmKernelParam::kOutputRank,
1,
0,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock>>,
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<AccDataType,
CDataType,
GroupedGemmKernelParam::kPadM,
GroupedGemmKernelParam::kPadN>>>;
template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmTraits = ck_tile::TileGemmTraits<GroupedGemmKernelParam::kPadM,
GroupedGemmKernelParam::kPadN,
GroupedGemmKernelParam::kPadK,
ALayout,
BLayout,
CLayout>;
template <typename ALayout, typename BLayout, typename CLayout>
using CodegenPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits<ALayout, BLayout, CLayout>>;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>,
CodegenGemmPolicy>;
template <typename ALayout, typename BLayout, typename CLayout>
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
CodegenGemmPipeline<ALayout, BLayout, CLayout>,
GemmEpilogue<CLayout>>;
}; // namespace
std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return ::Kernel<std::nullptr_t, std::nullptr_t, std::nullptr_t>::GetWorkSpaceSize(gemm_descs);
}
template <typename ALayout, typename BLayout, typename CLayout>
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* p_workspace_)
{
using GroupedGemmKernel = ::Kernel<ALayout, BLayout, CLayout>;
auto arguments = GroupedGemmKernel::MakeKargs(gemm_descs);
const dim3 grids = GroupedGemmKernel::GridSize(gemm_descs);
constexpr dim3 blocks = GroupedGemmKernel::BlockSize();
ck_tile::hip_check_error(hipMemcpyWithStream(
p_workspace_,
arguments.data(),
arguments.size() * sizeof(typename GroupedGemmKernel::GemmTransKernelArg),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
float ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GroupedGemmKernelParam::kBlockPerCu>(
GroupedGemmKernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(p_workspace_),
gemm_descs.size()));
return ave_time;
}
#include "run_grouped_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
template <typename DataType>
struct GemmBasicTypeConfig;
template <>
struct GemmBasicTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using CDataType = ck_tile::half_t;
using AccDataType = float;
};
using Types = GemmBasicTypeConfig<ck_tile::half_t>;
// Specific type aliases for easy access
using ADataType = Types::ADataType;
using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "R", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("validate", "1", "0. No validation, 1. Validation on CPU")
.insert("warmup", "10", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("group_count", "16", "group count");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs);
float grouped_gemm_calc(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* p_workspace_);
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <typename ALayout, typename BLayout, typename CLayout>
float invoke_gemm(int n_warmup,
int n_repeat,
int group_count,
const std::vector<grouped_gemm_kargs>& args)
{
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(GetWorkspaceSize(args));
float ave_time = grouped_gemm<ALayout, BLayout, CLayout>(
args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
gemm_workspace.GetDeviceBuffer());
std::string op_name{"Grouped Gemm"};
std::size_t flop = 0, num_btype = 0;
for(int j = 0; j < group_count; ++j)
{
flop += std::size_t(2) * args[j].M * args[j].N * args[j].K;
num_btype += sizeof(ADataType) * args[j].M * args[j].K +
sizeof(BDataType) * args[j].K * args[j].N +
sizeof(CDataType) * args[j].M * args[j].N;
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
return ave_time;
}
template <typename ALayout, typename BLayout, typename CLayout>
int run_grouped_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
};
const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
std::vector<ck_tile::index_t> Ms;
std::vector<ck_tile::index_t> Ns;
std::vector<ck_tile::index_t> Ks;
std::vector<ck_tile::index_t> stride_As;
std::vector<ck_tile::index_t> stride_Bs;
std::vector<ck_tile::index_t> stride_Cs;
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 256 * i);
Ns.push_back(128 + 128 * i);
Ks.push_back(128 + 64 * i);
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
stride_Cs.push_back(Ns[i]);
}
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
a_m_k_tensors.reserve(group_count);
b_k_n_tensors.reserve(group_count);
c_m_n_tensors.reserve(group_count);
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
a_m_k_dev_buf.reserve(group_count);
b_k_n_dev_buf.reserve(group_count);
c_m_n_dev_buf.reserve(group_count);
std::vector<grouped_gemm_kargs> gemm_descs;
gemm_descs.reserve(group_count);
for(int i = 0; i < group_count; ++i)
{
const ck_tile::index_t M = Ms[i];
const ck_tile::index_t N = Ns[i];
const ck_tile::index_t K = Ks[i];
stride_As[i] = f_get_default_stride(M, N, stride_As[i], a_layout);
stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], b_layout);
stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{});
a_m_k_tensors.push_back(
ck_tile::HostTensor<ADataType>(f_host_tensor_descriptor(M, K, stride_As[i], a_layout)));
b_k_n_tensors.push_back(
ck_tile::HostTensor<BDataType>(f_host_tensor_descriptor(K, N, stride_Bs[i], b_layout)));
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));
std::cout << "gemm[" << i << "]"
<< " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc
<< " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl;
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_tensors[i]);
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
a_m_k_tensors[i].get_element_space_size_in_bytes()));
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
b_k_n_tensors[i].get_element_space_size_in_bytes()));
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
c_m_n_tensors[i].get_element_space_size_in_bytes()));
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
c_m_n_dev_buf[i]->SetZero();
c_m_n_tensors[i].SetZero();
const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer();
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
gemm_descs.push_back({p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
}
invoke_gemm<ALayout, BLayout, CLayout>(warmup, repeat, group_count, gemm_descs);
for(int i = 0; i < group_count; i++)
{
c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
}
bool pass{true};
if(arg_parser.get_int("validate"))
{
for(int i = 0; i < group_count; ++i)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref);
}
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
}
return pass;
}
int run_grouped_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if(a_layout == "R" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "R" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <typename TLayout>
constexpr auto
f_host_tensor_descriptor(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
{
using namespace ck_tile::literals;
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
}
}
template <typename TLayout>
constexpr auto
f_get_default_stride(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
{
if(stride == 0)
{
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
}
...@@ -16,3 +16,4 @@ add_subdirectory(13_moe_sorting) ...@@ -16,3 +16,4 @@ add_subdirectory(13_moe_sorting)
add_subdirectory(14_moe_smoothquant) add_subdirectory(14_moe_smoothquant)
add_subdirectory(15_fused_moe) add_subdirectory(15_fused_moe)
add_subdirectory(16_batched_gemm) add_subdirectory(16_batched_gemm)
add_subdirectory(17_grouped_gemm)
[Back to the main page](../../README.md)
# Composable Kernel supported operations
## Supported device operations
* [Average pooling]()
* [Batched contraction]()
* [Batched gemm]()
* [Batchnorm]()
* [CGEMM]()
* [Contraction]()
* [Convolution]()
* [Image to Column and Column to Image]()
* [Elementwise]()
* [GEMM]()
* [Max pooling]()
* [Reduce]()
* [Normalization]()
* [Permute]()
* [Put]()
* [Softmax]()
...@@ -326,7 +326,7 @@ struct Tensor ...@@ -326,7 +326,7 @@ struct Tensor
std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); } std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
void SetZero() { ck::ranges::fill<T>(mData, 0); } void SetZero() { ck::ranges::fill<T>(mData, T{0}); }
template <typename F> template <typename F>
void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank) void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -37,7 +37,7 @@ struct GeneratorTensor_1<ck::half_t> ...@@ -37,7 +37,7 @@ struct GeneratorTensor_1<ck::half_t>
float value = 1.0; float value = 1.0;
template <typename... Is> template <typename... Is>
ck::bhalf_t operator()(Is...) ck::half_t operator()(Is...)
{ {
return ck::type_convert<ck::half_t>(value); return ck::type_convert<ck::half_t>(value);
} }
...@@ -62,7 +62,7 @@ struct GeneratorTensor_1<ck::f8_t> ...@@ -62,7 +62,7 @@ struct GeneratorTensor_1<ck::f8_t>
float value = 1.0; float value = 1.0;
template <typename... Is> template <typename... Is>
ck::bhalf_t operator()(Is...) ck::f8_t operator()(Is...)
{ {
return ck::type_convert<ck::f8_t>(value); return ck::type_convert<ck::f8_t>(value);
} }
...@@ -256,14 +256,33 @@ struct GeneratorTensor_Checkboard ...@@ -256,14 +256,33 @@ struct GeneratorTensor_Checkboard
} }
}; };
template <ck::index_t Dim> /**
* @brief Is used to generate sequential values based on the specified dimension.
*
* @tparam T The type of the tensor values.
* @tparam Dim The specific dimension used for generation.
*
* GeneratorTensor_Sequential<1>{} will generate the following values for a 3x3 tensor:
*
* 0 1 2
* 0 1 2
* 0 1 2
*
* Essentially, the values generated are logical coordinates of the generated element that
* correspond to dimension Dim. E.g. for 2-dimensional tensor and Dim=1, the values are the column
* indices.
*
*/
template <typename T, ck::index_t Dim>
struct GeneratorTensor_Sequential struct GeneratorTensor_Sequential
{ {
template <typename... Ts> template <typename... Ts>
float operator()(Ts... Xs) const T operator()(Ts... Xs) const
{ {
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}}; std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
return dims[Dim];
float tmp = dims[Dim];
return ck::type_convert<T>(tmp);
} }
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -106,89 +106,35 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -106,89 +106,35 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr index_t KPerBlock = K0PerBlock * K1; static constexpr index_t KPerBlock = K0PerBlock * K1;
static constexpr auto transform_conv_to_gemm = using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1<NDimSpatial,
TransformConvBwdDataToGemm_v1<NDimSpatial, ConvBackwardDataSpecialization,
ConvBackwardDataSpecialization, K1,
K1, K1,
K1, MPerBlock,
MPerBlock, NPerBlock,
NPerBlock, KPerBlock,
KPerBlock, true /* DoPadGemmM */,
true /* DoPadGemmM */, true /* DoPadGemmN */,
true /* DoPadGemmN */>{}; ALayout,
BLayout,
static auto GetDummyABDsEGridDescriptor() ELayout>;
{
const std::array<index_t, NDimSpatial + 3> dummy_tensor_lengths = {1};
const std::array<index_t, NDimSpatial + 3> dummy_tensor_strides = {1};
const std::array<index_t, NDimSpatial> dummy_spatial_lengths = {1};
const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
const auto ds_grid_desc_m_n = generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
},
Number<NumDTensor>{});
const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_tensor_lengths,
dummy_tensor_strides,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths,
dummy_spatial_lengths);
static auto
GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform)
{
const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1();
const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1();
const auto ds_grid_desc_m_n =
generate_tuple([&](auto) { return conv_to_gemm_transform.MakeCDescriptor_M_N(); },
Number<NumDTensor>{});
const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N();
return make_tuple( return make_tuple(
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
} }
// desc // desc
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor()); constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform;
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform));
using AGridDesc_AK0_M_AK1 = remove_cvref_t<tuple_element_t<0, ABDsEGridDesc>>; using AGridDesc_AK0_M_AK1 = remove_cvref_t<tuple_element_t<0, ABDsEGridDesc>>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<tuple_element_t<1, ABDsEGridDesc>>; using BGridDesc_BK0_N_BK1 = remove_cvref_t<tuple_element_t<1, ABDsEGridDesc>>;
...@@ -270,7 +216,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -270,7 +216,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_lengths, /*ds_g_n_c_wis_lengths*/,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_c_wis_strides, ds_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths,
...@@ -291,15 +237,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -291,15 +237,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op}, cde_element_op_{cde_element_op},
a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths}, a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
a_g_n_k_wos_strides_{a_g_n_k_wos_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths},
ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides},
e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths},
e_g_n_c_wis_strides_{e_g_n_c_wis_strides},
conv_filter_strides_{conv_filter_strides}, conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads}, input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads} input_right_pads_{input_right_pads}
{ {
...@@ -382,68 +321,47 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -382,68 +321,47 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
tildes = {i_ztilde, i_ytilde, i_xtilde}; tildes = {i_ztilde, i_ytilde, i_xtilde};
} }
ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes};
const auto a_grid_desc_ak0_m_ak1 = const auto a_grid_desc_ak0_m_ak1 =
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>( conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
const auto b_grid_desc_bk0_n_bk1 = const auto b_grid_desc_bk0_n_bk1 =
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>( conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
DsGridDesc_M_N ds_grid_desc_m_n; DsGridDesc_M_N ds_grid_desc_m_n;
// populate Ds desc // populate Ds desc
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
static_assert(is_same_v<DLayout, ELayout>);
ds_grid_desc_m_n(i) = ConvToGemmBwdDataTransform conv_to_gemm_transform_d{
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_c_wis_lengths[i],
ds_g_n_c_wis_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes);
});
const auto e_grid_desc_m_n =
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
a_g_n_k_wos_lengths, a_g_n_k_wos_lengths,
a_g_n_k_wos_strides, a_g_n_k_wos_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
e_g_n_c_wis_lengths, e_g_n_c_wis_lengths,
e_g_n_c_wis_strides, ds_g_n_c_wis_strides[i],
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
tildes); tildes};
ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N();
});
const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N();
// for check validity // for check validity
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
...@@ -522,17 +440,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -522,17 +440,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
BElementwiseOp b_element_op_; BElementwiseOp b_element_op_;
CDEElementwiseOp cde_element_op_; CDEElementwiseOp cde_element_op_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_; std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_; std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_lengths_;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_;
std::array<index_t, NDimSpatial> conv_filter_strides_; std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> conv_filter_dilations_;
std::array<index_t, NDimSpatial> input_left_pads_; std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_; std::array<index_t, NDimSpatial> input_right_pads_;
}; };
......
...@@ -111,8 +111,7 @@ __global__ void ...@@ -111,8 +111,7 @@ __global__ void
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
[[maybe_unused]] const index_t num_k_per_block) [[maybe_unused]] const index_t num_k_per_block)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
// offset base pointer for each work-group // offset base pointer for each work-group
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
......
...@@ -38,8 +38,7 @@ __global__ void ...@@ -38,8 +38,7 @@ __global__ void
// __attribute__((amdgpu_waves_per_eu(1, 1))) // __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
......
...@@ -549,8 +549,10 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src ...@@ -549,8 +549,10 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, f8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, bf8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, fp8_storage_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented"); "wrong! not implemented");
...@@ -843,8 +845,8 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -843,8 +845,8 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#else #else
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>( vector_t tmp{amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, 0)};
return src_thread_element_valid ? tmp : vector_t(0); return src_thread_element_valid ? tmp : vector_t(0);
#endif #endif
} }
...@@ -873,8 +875,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, ...@@ -873,8 +875,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
constexpr index_t vector_size = scalar_type<vector_t>::vector_size; constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>( vector_t tmp{amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, 0)};
return src_thread_element_valid ? tmp : vector_t(customized_value); return src_thread_element_valid ? tmp : vector_t(customized_value);
} }
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment