Commit 2b8a9941 authored by Artur Wojcik's avatar Artur Wojcik

Merge branch 'develop' into uif2-initial

parents ce9d7c8d 707ad002
...@@ -49,6 +49,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
vim \
nano \
zlib1g-dev \
zip \
openssh-server \
clang-format-12 \
kmod && \
......
...@@ -526,6 +526,26 @@ def Build_CK(Map conf=[:]){
stash "ckprofiler_0.2.0_amd64.deb"
}
}
if (params.hipTensor_test && navi_node == 0 ){
//build and test hipTensor
sh """#!/bin/bash
rm -rf "${params.hipTensor_branch}".zip
rm -rf hipTensor-"${params.hipTensor_branch}"
wget https://github.com/ROCmSoftwarePlatform/hipTensor/archive/refs/heads/"${params.hipTensor_branch}".zip
unzip -o "${params.hipTensor_branch}".zip
"""
dir("hipTensor-${params.hipTensor_branch}"){
sh """#!/bin/bash
mkdir -p build
ls -ltr
CC=hipcc CXX=hipcc cmake -Bbuild . -D CMAKE_PREFIX_PATH="/opt/rocm;${env.WORKSPACE}/install"
cmake --build build -- -j
"""
}
dir("hipTensor-${params.hipTensor_branch}/build"){
sh 'ctest'
}
}
}
}
}
...@@ -654,6 +674,15 @@ pipeline {
name: "DL_KERNELS",
defaultValue: false,
description: "Select whether to build DL kernels (default: OFF)")
booleanParam(
name: "hipTensor_test",
defaultValue: true,
description: "Use the CK build to verify hipTensor build and tests (default: ON)")
string(
name: 'hipTensor_branch',
defaultValue: 'mainline',
description: 'Specify which branch of hipTensor to use (default: mainline)')
}
environment{
dbuser = "${dbuser}"
......
add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp)
target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_operations)
\ No newline at end of file
...@@ -42,7 +42,7 @@ fastjsonschema==2.18.0
# via rocm-docs-core
gitdb==4.0.10
# via gitpython
gitpython==3.1.31 gitpython==3.1.35
# via rocm-docs-core
idna==3.4
# via requests
...@@ -103,7 +103,7 @@ requests==2.28.2
# via
#   pygithub
#   sphinx
rocm-docs-core>=0.20.0 rocm-docs-core==0.24.0
# via -r requirements.in
six==1.16.0
# via
......
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp)
set(target 1)
endif()
endforeach()
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F16;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using ELayout = Row;
struct AddScale
{
static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{};
static constexpr auto I2 = ck::Number<2>{};
static constexpr auto I3 = ck::Number<3>{};
__host__ __device__ constexpr void
operator()(ck::half4_t& a, const ck::half4_t& a0, const ck::half4_t& a1) const
{
const auto a0_v_t = ck::vector_type<ck::half_t, 4>{a0};
const auto a1_v_t = ck::vector_type<ck::half_t, 4>{a1};
auto r_v_t = ck::vector_type<ck::half_t, 4>{};
r_v_t.AsType<ck::half_t>()(I0) =
scale * (a0_v_t.AsType<ck::half_t>()[I0] + a1_v_t.AsType<ck::half_t>()[I0]);
r_v_t.AsType<ck::half_t>()(I1) =
scale * (a0_v_t.AsType<ck::half_t>()[I1] + a1_v_t.AsType<ck::half_t>()[I1]);
r_v_t.AsType<ck::half_t>()(I2) =
scale * (a0_v_t.AsType<ck::half_t>()[I2] + a1_v_t.AsType<ck::half_t>()[I2]);
r_v_t.AsType<ck::half_t>()(I3) =
scale * (a0_v_t.AsType<ck::half_t>()[I3] + a1_v_t.AsType<ck::half_t>()[I3]);
a = r_v_t.AsType<ck::half4_t>()[I0];
}
__host__ __device__ constexpr void
operator()(ck::half_t& a, const ck::half_t& a0, const ck::half_t& a1) const
{
a = scale * (a0 + a1);
}
// this attribute controls the copy_function applying element_wise_op with
// pack4_data
constexpr const static bool is_pack4_invocable = true;
float scale = 1.0;
};
struct AlphaBetaAdd
{
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, float, ck::half_t>(
ck::half_t& e, const float& c, const ck::half_t& d) const
{
e = ck::type_convert<ck::half_t>(alpha_ * c + beta_ * ck::type_convert<float>(d));
};
float alpha_;
float beta_;
};
using AElementOp = AddScale;
using BElementOp = PassThrough;
using CDEElementOp = AlphaBetaAdd;
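// Putting the element-wise operators together, this example computes (a restatement of the
// definitions above, where "x" denotes the GEMM contraction over K):
//   E = alpha * ((scale * (A0 + A1)) x B) + beta * D
// AddScale pre-combines the two A tensors, PassThrough leaves B untouched, and AlphaBetaAdd
// blends the accumulator with the auxiliary D tensor.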
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle<
ck::Tuple<ALayout, ALayout>,
ck::Tuple<BLayout>,
ck::Tuple<DLayout>,
ELayout,
ck::Tuple<ADataType, ADataType>,
ck::Tuple<BDataType>,
AccDataType,
CShuffleDataType,
ck::Tuple<DDataType>,
EDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
1,
256,
256,
128,
32,
8,
8,
32,
32,
4,
2,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
1,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
1,
1,
1,
S<1, 32, 1, 8>,
8>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideD = 4096;
ck::index_t StrideE = 4096;
float alpha = 1.0f;
float beta = 1.0f;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 6)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
alpha = std::stof(argv[4]);
beta = std::stof(argv[5]);
}
else if(argc == 13)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
alpha = std::stof(argv[11]);
beta = std::stof(argv[12]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, "
"beta\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<ADataType> a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
a1_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
a1_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
}
DeviceMem a0_device_buf(sizeof(ADataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(ADataType) * a1_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0_m_k.mData.data());
a1_device_buf.ToDevice(a1_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
d_device_buf.ToDevice(d_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{0.2};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(std::array<const void*, 2>{a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer()},
std::array<const void*, 1>{b_device_buf.GetDeviceBuffer()},
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
std::array<ck::index_t, 2>{StrideA, StrideA},
std::array<ck::index_t, 1>{StrideB},
std::array<ck::index_t, 1>{StrideD},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
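// For reference, with ave_time reported in milliseconds the two metrics above reduce to
//   TFLOPS = 2 * M * N * K / (ave_time * 1e9)
//   GB/s   = (bytes of A + bytes of B + bytes of E) / (ave_time * 1e6)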
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
Tensor<ADataType> a_m_k({M, K});
for(int m = 0; m < M; ++m)
{
for(int k = 0; k < K; ++k)
{
a_element_op(a_m_k(m, k), a0_m_k(m, k), a1_m_k(m, k));
}
}
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CShuffleDataType,
AccDataType,
PassThrough,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}
...@@ -20,7 +20,8 @@ template <typename ALayout,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation,
typename ComputeType = CDataType>
struct DeviceGemmSplitK : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
...@@ -48,7 +49,8 @@ template <typename ALayout,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation,
typename ComputeType = CDataType>
using DeviceGemmSplitKPtr = std::unique_ptr<DeviceGemmSplitK<ALayout,
BLayout,
CLayout,
...@@ -57,7 +59,8 @@ using DeviceGemmSplitKPtr = std::unique_ptr<DeviceGemmSplitK<ALayout,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>>; CElementwiseOperation,
ComputeType>>;
} // namespace device
} // namespace tensor_operation
......
...@@ -8,57 +8,6 @@ namespace ck {
namespace tensor_operation {
namespace device {
///
/// @brief Structure representing single GEMM problem arguments.
///
/// The pointer to the vector of those structures is passed
/// to the GroupedGEMM entry point kernel.
///
struct GroupedGemmKernelArguments
{
__host__ __device__ GroupedGemmKernelArguments(const void* p_a_grid_,
const void* p_b_grid_,
void* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_}
{
}
const void* p_a_grid;
const void* p_b_grid;
void* p_c_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
void Print() const
{
std::cout << "arg {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << "}" << std::endl;
}
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
...@@ -82,28 +31,7 @@ struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm<ALayout,
BElementwiseOperation,
CElementwiseOperation>
{
//----------------------------------------------------------------------------------------------
virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;
/// @brief Sets the k batch size.
///
/// @param p_arg Pointer to the Argument we're going to change.
/// @param[in] kbatch The kbatch value.
///
virtual void SetKBatchSize([[maybe_unused]] BaseArgument* p_arg,
[[maybe_unused]] index_t kbatch) const
{
}
//----------------------------------------------------------------------------------------------
/// @brief Sets the device kernel arguments pointer.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
/// arguments.
///
virtual void SetDeviceKernelArgs([[maybe_unused]] BaseArgument* p_arg,
[[maybe_unused]] const void* p_dev_kernel_args) const
{
}
};
} // namespace device
......
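With the interface reduced to a pure virtual SetKBatchSize, callers now choose the split-K factor themselves. A minimal calling sketch under that assumption (the helper below is illustrative only, not part of this change):

// Hypothetical helper (illustrative, not from the diff): set the split-K factor through the
// polymorphic SetKBatchSize hook before launching a grouped GEMM.
template <typename DeviceOp, typename Invoker, typename BaseArgument, typename Index>
float RunGroupedGemmWithKBatch(DeviceOp& device_op, Invoker& invoker, BaseArgument* p_arg, Index kbatch)
{
    device_op.SetKBatchSize(p_arg, kbatch); // pure virtual in the updated DeviceGroupedGemmSplitK
    return invoker.Run(p_arg);              // average kernel time when timing is enabled
}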
...@@ -22,22 +22,22 @@ template <typename InDataType, ...@@ -22,22 +22,22 @@ template <typename InDataType,
index_t NumReduceDim> index_t NumReduceDim>
struct DeviceSoftmax : public BaseOperator struct DeviceSoftmax : public BaseOperator
{ {
/// //
/// @brief Makes a pointer to Argument class. // @brief Makes a pointer to Argument class.
/// //
/// @param[in] inLengths Input tensor extent(s) from high to low dimension // @param[in] inLengths Input tensor extent(s) from high to low dimension
/// @param[in] inStrides Input tensor stride(s) from high to low dimension // @param[in] inStrides Input tensor stride(s) from high to low dimension
/// @param[in] reduceDims The dimension(s) the normalization operation is applied // @param[in] reduceDims The dimension(s) the normalization operation is applied
/// @param[in] alpha double type value // @param[in] alpha double type value
/// @param[in] beta double type value // @param[in] beta double type value
/// @param[in] in_dev Typeless const pointer in device memory storing the input // @param[in] in_dev Typeless const pointer in device memory storing the input
/// tensor // tensor
/// @param out_dev Typeless pointer in device memory storing the output tensor // @param out_dev Typeless pointer in device memory storing the output tensor
/// @param[in] in_elementwise_op The input elementwise operation. // @param[in] in_elementwise_op The input elementwise operation.
/// @param[in] acc_elementwise_op The accumulation elementwise operation. // @param[in] acc_elementwise_op The accumulation elementwise operation.
/// //
/// @return Unique pointer to the Argument class. // @return Unique pointer to the Argument class.
/// //
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> inLengths, MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides, const std::vector<index_t> inStrides,
......
...@@ -69,7 +69,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation> CElementwiseOperation,
ComputeType>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
...@@ -168,7 +169,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg, b2c_map);
};
if(has_main_k0_block_loop)
......
...@@ -157,22 +157,22 @@ __global__ void ...@@ -157,22 +157,22 @@ __global__ void
} }
} // namespace } // namespace
/// //
/// @brief Device Convolution operation. // @brief Device Convolution operation.
/// //
/// Supports: // Supports:
/// @li Forward convolution with up to 3 spatial dimentions // @li Forward convolution with up to 3 spatial dimentions
/// @li Input tensor in GNWC data format // @li Input tensor in GNWC data format
/// @li Weight tensor in GKXC data format // @li Weight tensor in GKXC data format
/// @li Output tensor in GNWK data format // @li Output tensor in GNWK data format
/// //
/// 1D: // 1D:
/// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] // out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
/// 2D: // 2D:
/// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] // out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
/// 3D: // 3D:
/// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] // out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
/// //
template <index_t NDimSpatial, template <index_t NDimSpatial,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
......
...@@ -154,22 +154,22 @@ __global__ void ...@@ -154,22 +154,22 @@ __global__ void
} // namespace } // namespace
/// //
/// @brief Device Convolution operation. // @brief Device Convolution operation.
/// //
/// Supports: // Supports:
/// @li Forward convolution with up to 3 spatial dimentions // @li Forward convolution with up to 3 spatial dimentions
/// @li Input tensor in GNWC data format // @li Input tensor in GNWC data format
/// @li Weight tensor in GKXC data format // @li Weight tensor in GKXC data format
/// @li Output tensor in GNWK data format // @li Output tensor in GNWK data format
/// //
/// 1D: // 1D:
/// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] // out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
/// 2D: // 2D:
/// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] // out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
/// 3D: // 3D:
/// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] // out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
/// //
template < template <
index_t NDimSpatial, index_t NDimSpatial,
typename ADataType, typename ADataType,
......
...@@ -150,22 +150,22 @@ __global__ void ...@@ -150,22 +150,22 @@ __global__ void
} // namespace } // namespace
/// //
/// @brief Device Convolution operation. // @brief Device Convolution operation.
/// //
/// Supports: // Supports:
/// @li Forward convolution with up to 3 spatial dimentions // @li Forward convolution with up to 3 spatial dimentions
/// @li Input tensor in GNWC data format // @li Input tensor in GNWC data format
/// @li Weight tensor in GKXC data format // @li Weight tensor in GKXC data format
/// @li Output tensor in GNWK data format // @li Output tensor in GNWK data format
/// //
/// 1D: // 1D:
/// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] // out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
/// 2D: // 2D:
/// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] // out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
/// 3D: // 3D:
/// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] // out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
/// //
template <index_t NDimSpatial, template <index_t NDimSpatial,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
......
...@@ -5,13 +5,11 @@
#include <iostream>
#include <sstream>
#include <tuple>
#include "ck/ck.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/host_utility/stream_utility.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
...@@ -25,28 +23,8 @@ namespace ck {
namespace tensor_operation {
namespace device {
///
/// @brief Entry point kernel for device-wide Grouped GEMM operation.
///
/// @param[in] gemm_descs_const The pointer to the array of GEMM descriptor structures.
/// @param[in] tile_count The overall number of output tiles we divided all groups
/// into.
/// @param[in] k_batch The number of batches we split the K dimension into.
///
/// @tparam GridwiseGemm The specific GridwiseGEMM algorithm implementation.
/// @tparam GemmDesc The structure holding all necessary descriptors and
/// other data needed for groupd gemm calculation and work
/// distribution.
/// @tparam HasMainKBlockLoop Flag indicating whether all GEMM problem configurations
/// need to loop over tiles in K dimension.
/// @tparam CGlobalMemoryDataOperation The functor used to store data in output C matrix.
/// In example could be: AtomicAdd or Store.
///
template <typename GridwiseGemm,
typename GemmDesc,
typename FloatA,
typename FloatB,
typename FloatC,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__global__ void
...@@ -54,99 +32,42 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t tile_count, const index_t group_count)
const index_t k_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
index_t tile_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
const index_t grid_size = get_grid_size();
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock(); index_t left = 0;
static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock(); index_t right = group_count;
static constexpr index_t B2E_M01 = 8; index_t group_id = index_t((left + right) / 2);
while((!(block_id >= gemm_desc_ptr[group_id].block_start_ &&
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; block_id < gemm_desc_ptr[group_id].block_end_)) &&
using Block2ETileMapKSplit = left <= right)
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
index_t group_id = 0;
index_t offset = 0;
auto M = gemm_desc_ptr[group_id].M;
auto N = gemm_desc_ptr[group_id].N;
auto StrideC = gemm_desc_ptr[group_id].StrideC;
auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
auto b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
index_t grid_size_grp = b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
index_t gemm_tile_id_start = 0;
index_t gemm_tile_id_end = grid_size_grp;
while(tile_id < tile_count)
{ {
// Find corresponding GEMM group for out tile if(block_id < gemm_desc_ptr[group_id].block_start_)
while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end))
{ {
offset += grid_size_grp; right = group_id;
group_id++;
M = gemm_desc_ptr[group_id].M;
N = gemm_desc_ptr[group_id].N;
StrideC = gemm_desc_ptr[group_id].StrideC;
c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
grid_size_grp = b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
gemm_tile_id_start = offset;
gemm_tile_id_end = offset + grid_size_grp;
} }
else
const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid); {
const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid); left = group_id;
const auto p_c_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_c_grid); }
group_id = index_t((left + right) / 2);
const auto K = gemm_desc_ptr[group_id].K;
const auto StrideA = gemm_desc_ptr[group_id].StrideA;
const auto StrideB = gemm_desc_ptr[group_id].StrideB;
const auto MPadded = GridwiseGemm::CalculateMPadded(M);
const auto NPadded = GridwiseGemm::CalculateNPadded(N);
const auto KPadded = GridwiseGemm::CalculateKPadded(K, k_batch);
const auto K0 = GridwiseGemm::CalculateK0(K, k_batch);
LocalBlockToCTileMap<Block2ETileMapKSplit> local_b2c{b2c_tile_map, tile_id - offset};
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
p_a_grid,
p_b_grid,
p_c_grid,
M,
N,
K,
StrideA,
StrideB,
StrideC,
MPadded,
NPadded,
KPadded,
K0,
k_batch,
static_cast<void*>(p_shared),
local_b2c);
tile_id += grid_size;
} }
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_,
static_cast<void*>(p_shared),
gemm_desc_ptr[group_id].block_2_ctile_map_);
#else
ignore = gemm_descs_const;
ignore = tile_count; ignore = group_count;
ignore = k_batch;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
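The new kernel body maps each workgroup to its GEMM group by binary-searching the per-group block-id ranges carried in GemmTransKernelArg (block_start_, block_end_). A minimal host-side sketch of that lookup, assuming the ranges are contiguous, half-open, and cover every launched block id:

#include <cassert>
#include <vector>

// Sketch of the group lookup performed in kernel_grouped_gemm_xdl_splitk: each group owns the
// half-open block-id range [block_start_, block_end_), and the ranges tile the whole grid.
struct GroupRange
{
    int block_start_;
    int block_end_;
};

inline int FindGroupId(const std::vector<GroupRange>& groups, int block_id)
{
    int left     = 0;
    int right    = static_cast<int>(groups.size());
    int group_id = (left + right) / 2;

    while(!(block_id >= groups[group_id].block_start_ && block_id < groups[group_id].block_end_) &&
          left <= right)
    {
        if(block_id < groups[group_id].block_start_)
            right = group_id; // target group lies in the lower half
        else
            left = group_id; // target group lies in the upper half
        group_id = (left + right) / 2;
    }

    assert(block_id >= groups[group_id].block_start_ && block_id < groups[group_id].block_end_);
    return group_id;
}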
...@@ -265,13 +186,33 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
LoopSched,
PipelineVer>;
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using GridwiseGemmArg = typename GridwiseGemm::Argument;
using KernelArguments = GroupedGemmKernelArguments;
using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
// Block2CTileMap configuration parameter.
static constexpr index_t B2E_M01 = 8;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
using KernelArgument = typename GridwiseGemm::Argument;
struct GemmTransKernelArg
{
KernelArgument karg_;
GroupedGemmBlock2ETileMap block_2_ctile_map_;
index_t block_start_, block_end_;
GemmTransKernelArg() = default;
GemmTransKernelArg(KernelArgument&& karg,
GroupedGemmBlock2ETileMap&& b2c_map,
index_t block_start,
index_t block_end)
: karg_{karg},
block_2_ctile_map_{b2c_map},
block_start_{block_start},
block_end_{block_end}
{
}
};
static constexpr index_t DefaultKBatch = 1; static constexpr index_t DefaultKBatch = 1;
// Argument // Argument
...@@ -284,6 +225,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -284,6 +225,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
std::vector<GemmDesc>& gemm_descs) std::vector<GemmDesc>& gemm_descs)
: Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch) : Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch)
{ {
// TODO: use occupancy api to calculate appropriate batch size.
} }
Argument(std::vector<const void*>& p_As, Argument(std::vector<const void*>& p_As,
...@@ -291,8 +233,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -291,8 +233,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
std::vector<void*>& p_Es, std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs, std::vector<GemmDesc>& gemm_descs,
index_t kbatch) index_t kbatch)
: K_BATCH{kbatch}, group_count_{0}, skipped_group_count_{0}, grid_size_{0} : K_BATCH{kbatch}
{ {
grid_size_ = 0;
group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size()); group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) && if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
...@@ -304,6 +247,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -304,6 +247,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
gemm_kernel_args_.reserve(group_count_); gemm_kernel_args_.reserve(group_count_);
skipped_group_count_ = 0;
for(std::size_t i = 0; i < gemm_descs.size(); ++i) for(std::size_t i = 0; i < gemm_descs.size(); ++i)
{ {
const index_t M = gemm_descs[i].M_; const index_t M = gemm_descs[i].M_;
...@@ -320,29 +265,51 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -320,29 +265,51 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const index_t stride_b = gemm_descs[i].stride_B_; const index_t stride_b = gemm_descs[i].stride_B_;
const index_t stride_c = gemm_descs[i].stride_C_; const index_t stride_c = gemm_descs[i].stride_C_;
const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
const index_t k0 = GridwiseGemm::CalculateK0(K, K_BATCH);
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c); const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c);
auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; const auto local_b2c_tile_map =
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
const index_t block_start = grid_size_;
const index_t block_end = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
gemm_kernel_args_.emplace_back(type_convert<const ADataType*>(p_As[i]), // block-to-e-tile map
type_convert<const BDataType*>(p_Bs[i]), auto grouped_block_2_ctile_map =
type_convert<EDataType*>(p_Es[i]), GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
M,
N, auto karg = KernelArgument{type_convert<const ADataType*>(p_As[i]),
K, type_convert<const BDataType*>(p_Bs[i]),
stride_a, type_convert<EDataType*>(p_Es[i]),
stride_b, M,
stride_c); N,
K,
stride_a,
stride_b,
stride_c,
m_padded,
n_padded,
k_padded,
k0,
K_BATCH};
gemm_kernel_args_.emplace_back(
std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
} }
} }
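// With grid_size_grp denoting the number of output tiles (workgroups) needed by group i, the
// ranges recorded above are simply the running sum over groups:
//   block_start_i = sum of grid_size_grp_j for j < i
//   block_end_i   = block_start_i + grid_size_grp_i
// so the final grid_size_ equals the total tile count across all groups.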
/// /**
/// @brief Set new kbatch value. * @brief Recalculate group grid size for all gemms and update B2C maps.
/// *
/// @param[in] kbatch The new splitK parameter value. * @param[in] kbatch The new splitK parameter value.
/// */
void UpdateKBatch(index_t kbatch) void UpdateKBatch(index_t kbatch)
{ {
K_BATCH = kbatch; K_BATCH = kbatch;
...@@ -351,14 +318,33 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -351,14 +318,33 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i) for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
{ {
auto& gemm_arg = gemm_kernel_args_[i]; auto& karg = gemm_kernel_args_[i].karg_;
const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
const index_t k0 = GridwiseGemm::CalculateK0(karg.K, K_BATCH);
const auto c_grid_desc_m_n = const auto c_grid_desc_m_n =
GridwiseGemm::MakeCGridDescriptor_M_N(gemm_arg.M, gemm_arg.N, gemm_arg.StrideC); GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; const auto local_b2c_tile_map =
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
const index_t block_start = grid_size_;
const index_t block_end = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
// block-to-e-tile map
auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
karg.KPadded = k_padded;
karg.K0 = k0;
karg.k_batch = K_BATCH;
gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
gemm_kernel_args_[i].block_start_ = block_start;
gemm_kernel_args_[i].block_end_ = block_end;
} }
} }
...@@ -366,167 +352,31 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -366,167 +352,31 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
index_t K_BATCH; index_t K_BATCH;
index_t group_count_; index_t group_count_;
index_t skipped_group_count_; index_t skipped_group_count_;
// The overall number of output tiles to be processed.
index_t grid_size_;
const void* p_dev_gemm_args_;
std::vector<KernelArguments> gemm_kernel_args_; std::vector<GemmTransKernelArg> gemm_kernel_args_;
index_t grid_size_;
}; };
// Invoker // Invoker
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
// The oversubscription factor for the number of blocks that can simultaneously reside on
// GPU.
static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1;
static constexpr int BLOCK_WAVES = BlockSize / get_warp_size();
static constexpr int CU_SIMDS = 4;
// Assume we want to have at most 2 waves per SIMD
static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
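// Worked example (assumed values, not stated in this file): with BlockSize = 256 and a wave
// size of 64, BLOCK_WAVES = 256 / 64 = 4 and CU_BLOCKS = floor(2 * 4 / 4) = 2, i.e. the
// occupancy-based grid sizing targets at most two of these workgroups per CU.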
///
/// @brief Launch Grouped Gemm kernel.
///
/// @note This function overload is using user provided device buffer for kernel
/// arguments.
///
/// @param[in] arg The structure containing kernel arguments (in host memory).
/// @param[in] dev_gemm_args The point to device memory with kernel arguments.
/// @param[in] stream_config The device stream configuration.
///
/// @return The average kernel execution time (if time measurement is enabled.)
///
float Run(const Argument& arg,
const void* dev_gemm_args,
const StreamConfig& stream_config = StreamConfig{})
{
auto [all_have_kbatch_gt_one, all_have_main_k0_block_loop] =
CheckArgument(arg, stream_config);
if(dev_gemm_args == nullptr)
{
std::ostringstream err;
err << "The gemm arguments workspace buffer is not allocated!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
if(all_have_kbatch_gt_one)
{
for(const auto& gemm_arg : arg.gemm_kernel_args_)
{
hip_check_error(hipMemsetAsync(gemm_arg.p_c_grid,
0,
gemm_arg.M * gemm_arg.N * sizeof(EDataType),
stream_config.stream_id_));
}
}
float ave_time = 0;
if(all_have_main_k0_block_loop)
{
if(all_have_kbatch_gt_one)
{
ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, true>(
arg, dev_gemm_args, stream_config);
}
else
{
ave_time = DispatchKernel<InMemoryDataOperationEnum::Set, true>(
arg, dev_gemm_args, stream_config);
}
}
else
{
if(all_have_kbatch_gt_one)
{
ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, false>(
arg, dev_gemm_args, stream_config);
}
else
{
ave_time = DispatchKernel<InMemoryDataOperationEnum::Set, false>(
arg, dev_gemm_args, stream_config);
}
}
return ave_time;
}
///
/// @brief Launch Grouped Gemm kernel.
///
/// @note This function overload is using device workspace buffer for kernel
/// arguments. The user should call @see GetWorkSpaceSize and @see
/// SetWorkSpacePointer on arg parameter to properly allocate this buffer.
///
/// @param[in] arg The structure containing kernel arguments (in host memory).
/// @param[in] stream_config The device stream configuration.
///
/// @return The average kernel execution time (if time measurement is enabled.)
///
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
if(arg.p_workspace_ != nullptr) index_t K0 = arg.gemm_kernel_args_[0].karg_.K0;
{ bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
hip_check_error(
hipMemcpyWithStream(arg.p_workspace_,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(KernelArguments),
hipMemcpyHostToDevice,
stream_config.stream_id_));
}
else
{
std::ostringstream err;
err << "The gemm arguments workspace buffer is not allocated!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
return Run(arg, arg.p_workspace_, stream_config);
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
private:
auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const
{
index_t K0 = GridwiseGemm::CalculateK0(arg.gemm_kernel_args_[0].K, arg.K_BATCH);
bool all_have_kbatch_gt_one = arg.K_BATCH > 1;
bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{ {
const auto& gemm_arg = arg.gemm_kernel_args_[i]; const auto& karg = arg.gemm_kernel_args_[i].karg_;
if(stream_config.log_level_ > 0) if(stream_config.log_level_ > 0)
{ {
gemm_arg.Print(); karg.Print();
} }
// Currently all groups use same kbatch value. auto kbatch = karg.k_batch;
auto kbatch = arg.K_BATCH;
K0 = GridwiseGemm::CalculateK0(arg.gemm_kernel_args_[i].K, arg.K_BATCH); if(!GridwiseGemm::CheckValidity(karg))
if(!GridwiseGemm::CheckValidity(GridwiseGemmArg{nullptr,
nullptr,
nullptr,
gemm_arg.M,
gemm_arg.N,
gemm_arg.K,
gemm_arg.StrideA,
gemm_arg.StrideB,
gemm_arg.StrideC,
0, // MPadded
0, // NPadded
0, // KPadded
K0,
kbatch}))
{ {
std::ostringstream err; std::ostringstream err;
err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__ err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
...@@ -534,6 +384,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -534,6 +384,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
throw std::runtime_error(err.str()); throw std::runtime_error(err.str());
} }
K0 = karg.K0;
bool not_all_have_main_k0_block_loop_same = bool not_all_have_main_k0_block_loop_same =
all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0); all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1); bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
...@@ -551,75 +402,99 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -551,75 +402,99 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
std::ostringstream err; std::ostringstream err;
err << "Not all gemms have same kbatch value (=1 or >1)! " err << "Not all gemms have same kbatch value (=1 or >1)! "
<< "group [" << i << "], kbatch: " << kbatch << "group [" << i << "], kbatch: " << kbatch
<< ", group [0], kbatch: " << arg.K_BATCH << " in " << __FILE__ << ":" << ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.k_batch
<< __LINE__ << ", in function: " << __func__; << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str()); throw std::runtime_error(err.str());
} }
} }
return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k0_block_loop);
}
template <InMemoryDataOperationEnum CGlobalMemoryDataOperation, bool HasMainKBlockLoop> hip_check_error(
float DispatchKernel(const Argument& arg, hipMemcpyWithStream(arg.p_workspace_,
const void* dev_gemm_args, arg.gemm_kernel_args_.data(),
const StreamConfig& stream_config) const arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
{ hipMemcpyHostToDevice,
const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm, stream_config.stream_id_));
KernelArguments,
ADataType,
BDataType,
EDataType,
HasMainKBlockLoop,
CGlobalMemoryDataOperation>;
return LaunchKernel(kernel, arg, dev_gemm_args, stream_config);
}
template <typename KernelFunction> float ave_time = 0;
int CalculateMaxOccupancyGridSize(const KernelFunction& kernel,
const StreamConfig& stream_config) const
{
// Calculate max number of workgroups that can simultaneously reside on the CU.
int num_blocks = 0;
size_t dyn_shared_mem_per_blk = 0;
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk));
int cu_count = getAvailableComputeUnitCount(stream_config); const auto Run = [&](const auto& kernel) {
if(all_have_kbatch_gt_one)
{
for(const auto& trans_arg : arg.gemm_kernel_args_)
{
const auto& karg = trans_arg.karg_;
hip_check_error(hipMemsetAsync(karg.p_c_grid,
0,
karg.M * karg.N * sizeof(EDataType),
stream_config.stream_id_));
}
}
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_kernel_args_.size());
};
if(stream_config.log_level_ > 0) if(all_have_main_k0_block_loop)
{ {
std::cout << "MaxActiveBlocksPerCU: " << num_blocks if(all_have_kbatch_gt_one)
<< ", available CUs count: " << cu_count << ", occup. grid size: " {
<< ck::math::min(num_blocks, CU_BLOCKS) * cu_count * const auto kernel =
BLOCK_SUBSCRIPTION_FACTOR kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
<< std::endl; GemmTransKernelArg,
true,
InMemoryDataOperationEnum::AtomicAdd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::Set>;
Run(kernel);
}
}
else
{
if(all_have_kbatch_gt_one)
{
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
false,
InMemoryDataOperationEnum::AtomicAdd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
false,
InMemoryDataOperationEnum::Set>;
Run(kernel);
}
} }
return cu_count * ck::math::min(num_blocks, CU_BLOCKS) * BLOCK_SUBSCRIPTION_FACTOR; return ave_time;
} }
template <typename KernelFunction> // polymorphic
float LaunchKernel(const KernelFunction& kernel, float Run(const BaseArgument* p_arg,
const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) override
const void* dev_gemm_args,
const StreamConfig& stream_config) const
{ {
int max_occupancy_grid_size = CalculateMaxOccupancyGridSize(kernel, stream_config); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
// We launch the smaller number of workgroups from acutally needed tiles and the
// number of workgroups that maximize the GPU occupancy. That is because for some tile
// configuration the first is smaller than the latter. Launching too many workgroups
// mean some of them will have to iterate through all gemm problem descriptors just to
// find out they have nothing to do which is of course waste of GPU cycles.
return launch_and_time_kernel(
stream_config,
kernel,
dim3(ck::math::min(arg.grid_size_, max_occupancy_grid_size)),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(dev_gemm_args),
arg.grid_size_,
arg.K_BATCH);
} }
}; };
...@@ -631,6 +506,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -631,6 +506,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported())
{
return false;
}
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) + if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
arg.skipped_group_count_) != arg.group_count_) arg.skipped_group_count_) != arg.group_count_)
{ {
...@@ -645,28 +525,14 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -645,28 +525,14 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
bool supported = true; bool supported = true;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{ {
const auto& gemm_arg = arg.gemm_kernel_args_[i]; const auto& a = arg.gemm_kernel_args_[i].karg_;
const auto K0 = GridwiseGemm::CalculateK0(gemm_arg.K, arg.K_BATCH); bool group_arg_valid = GridwiseGemm::CheckValidity(a);
bool group_arg_valid = GridwiseGemm::CheckValidity(GridwiseGemmArg{nullptr,
nullptr,
nullptr,
gemm_arg.M,
gemm_arg.N,
gemm_arg.K,
gemm_arg.StrideA,
gemm_arg.StrideB,
gemm_arg.StrideC,
0, // MPadded
0, // NPadded
0, // KPadded
K0,
arg.K_BATCH});
if(not group_arg_valid) if(not group_arg_valid)
{ {
#if DEBUG_LOG #if DEBUG_LOG
std::cout << "[" << __func__ << "] group id: " << i std::cout << "[" << __func__ << "] group id: " << i
<< " has invalid GridwiseGemm settings!" << std::endl; << " has invalid GridwiseGemm settings!" << std::endl;
gemm_arg.Print(); a.Print();
#endif // DEBUG_LOG #endif // DEBUG_LOG
} }
supported = supported && group_arg_valid; supported = supported && group_arg_valid;
...@@ -674,6 +540,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -674,6 +540,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
return supported; return supported;
} }
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
{ {
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
...@@ -693,6 +560,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -693,6 +560,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_As, MakeArgumentPointer(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs, std::vector<const void*>& p_Bs,
...@@ -706,17 +574,19 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -706,17 +574,19 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
return std::make_unique<Argument>(p_As, p_Bs, p_Es, gemm_descs); return std::make_unique<Argument>(p_As, p_Bs, p_Es, gemm_descs);
} }
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
return std::make_unique<Invoker>(Invoker{}); return std::make_unique<Invoker>(Invoker{});
} }
// polymorphic
std::string GetTypeString() const override std::string GetTypeString() const override
{ {
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceGroupedGemm_XdlSplitKTileLoop" str << "DeviceGroupedGemm_XdlSplitK"
<< "<" << "<"
<< std::string(ALayout::name)[0] << "," << std::string(ALayout::name)[0] << ","
<< std::string(BLayout::name)[0] << "," << std::string(BLayout::name)[0] << ","
...@@ -735,9 +605,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -735,9 +605,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
<< BBlockTransferSrcScalarPerVector << ", " << BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", " << CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", " << CShuffleNXdlPerWavePerShuffle << ", "
<< ABlockTransferThreadClusterLengths_K0_M_K1{} << ", " << getGemmSpecializationString(GemmSpec)
<< getGemmSpecializationString(GemmSpec) << ", "
<< PipelineVer
<< ">"; << ">";
// clang-format on // clang-format on
...@@ -747,24 +615,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -747,24 +615,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{ {
return dynamic_cast<const Argument*>(p_arg)->gemm_kernel_args_.size() * return dynamic_cast<const Argument*>(p_arg)->gemm_kernel_args_.size() *
sizeof(KernelArguments); sizeof(GemmTransKernelArg);
} }
static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); } static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
static void SetDeviceKernelArgs(Argument& arg, const void* p_dev_kernel_args)
{
arg.p_dev_gemm_args_ = p_dev_kernel_args;
}
// polymorphic
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
{ {
return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch); return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
} }
void SetDeviceKernelArgs(BaseArgument* p_arg, const void* p_dev_kernel_args) const override
{
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
}
};
} // namespace device
......
...@@ -348,24 +348,24 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -348,24 +348,24 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
acc_elementwise_op}; acc_elementwise_op};
}; };
/// //
/// @brief Makes a pointer to Argument class. // @brief Makes a pointer to Argument class.
/// //
/// @param[in] inLengths Input tensor extent(s) from high to low dimension // @param[in] inLengths Input tensor extent(s) from high to low dimension
/// @param[in] inStrides Input tensor stride(s) from high to low dimension // @param[in] inStrides Input tensor stride(s) from high to low dimension
/// @param[in] reduceDims The dimension(s) the normalization operation is applied // @param[in] reduceDims The dimension(s) the normalization operation is applied
/// @param[in] alpha Typeless pointer in host memory storing the alpha scaling // @param[in] alpha Typeless pointer in host memory storing the alpha scaling
/// value as type AccDataType // value as type AccDataType
/// @param[in] beta Typeless pointer in host memory storing the beta scaling // @param[in] beta Typeless pointer in host memory storing the beta scaling
/// value as type AccDataType // value as type AccDataType
/// @param[in] in_dev Typeless const pointer in device memory storing the input // @param[in] in_dev Typeless const pointer in device memory storing the input
/// tensor // tensor
/// @param out_dev Typeless pointer in device memory storing the output tensor // @param out_dev Typeless pointer in device memory storing the output tensor
/// @param[in] in_elementwise_op The input elementwise operation. // @param[in] in_elementwise_op The input elementwise operation.
/// @param[in] acc_elementwise_op The accumulation elementwise operation. // @param[in] acc_elementwise_op The accumulation elementwise operation.
/// //
/// @return Unique pointer to the Argument class. // @return Unique pointer to the Argument class.
/// //
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths, std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides, const std::vector<index_t> inStrides,
const std::vector<int> reduceDims, const std::vector<int> reduceDims,
......
...@@ -271,8 +271,7 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt ...@@ -271,8 +271,7 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
{ {
} }
__host__ __device__ constexpr index_t __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{ {
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
...@@ -625,35 +624,23 @@ struct OffsettedBlockToCTileMap ...@@ -625,35 +624,23 @@ struct OffsettedBlockToCTileMap
index_t block_start_; index_t block_start_;
}; };
/// /**
/// @brief Simple tile mapping which creates 3D grid of block of threads. * @brief Simple tile mapping which creates 3D grid of block of threads.
/// *
/// @paragraph Description * @paragraph Description
/// This Block-to-C-tile-map creates a 3D grid (n_blocks, m_blocks, z_blocks) of thread * This Block-to-C-tile-map creates a 3D grid (n_blocks, m_blocks, z_blocks) of thread
/// blocks. The first two dimensions are regular 2D tiles created by dividing the output GEMM * blocks. The first two dimensions are regular 2D tiles created by dividing the output GEMM
/// dimensions by the corresponding tile size. The third dimension (Z) is a k-split * dimensions by the corresponding tile size. The third dimension (Z) is a k-split dimension,
/// dimension, which denotes the number of blocks used to split the work along the * which denotes the number of blocks used to split the work along the GEMM K dimension.
/// GEMM K dimension. *
/// * @tparam MPerBlock Output block tile size in M dimension.
/// @tparam MPerBlock Output block tile size in M dimension. * @tparam NPerBlock Output block tile size in N dimension.
/// @tparam NPerBlock Output block tile size in N dimension. */
///
template <index_t MPerBlock, index_t NPerBlock> template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_3DGrid_KSplit struct BlockToCTileMap_3DGrid_KSplit
{ {
__host__ __device__ BlockToCTileMap_3DGrid_KSplit() = default;
/// __host__ __device__ BlockToCTileMap_3DGrid_KSplit() = default;
/// @brief Constructs a new instance.
///
/// @param[in] top_idx Swallow blockIdx.
///
/// @tparam TopIdx The type of block index.
///
template <typename TopIdx>
__host__ __device__ BlockToCTileMap_3DGrid_KSplit([[maybe_unused]] TopIdx top_idx)
{
}
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
CalculateGridSize(index_t M, index_t N, index_t k_split) const CalculateGridSize(index_t M, index_t N, index_t k_split) const
...@@ -665,7 +652,8 @@ struct BlockToCTileMap_3DGrid_KSplit ...@@ -665,7 +652,8 @@ struct BlockToCTileMap_3DGrid_KSplit
return std::make_tuple(N0, M0, k_split); return std::make_tuple(N0, M0, k_split);
} }
__device__ constexpr auto CalculateBottomIndex() const template <typename TopIdx>
__device__ constexpr auto CalculateBottomIndex(const TopIdx&) const
{ {
return make_tuple(blockIdx.z, blockIdx.y, blockIdx.x); return make_tuple(blockIdx.z, blockIdx.y, blockIdx.x);
} }
...@@ -684,53 +672,6 @@ struct BlockToCTileMap_3DGrid_KSplit ...@@ -684,53 +672,6 @@ struct BlockToCTileMap_3DGrid_KSplit
} }
}; };
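To make the mapping above concrete, here is a small host-side sketch of how the 3D k-split map sizes its grid and what each blockIdx component ends up meaning; the tile sizes and problem shape are illustrative values only.

#include <tuple>

// Same rounding CK's math::integer_divide_ceil performs.
constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    constexpr int MPerBlock = 256, NPerBlock = 128;  // output block tile sizes (illustrative)
    const int M = 3840, N = 4096, k_split = 4;       // GEMM extents plus the K-split factor

    const int M0 = integer_divide_ceil(M, MPerBlock); // thread blocks along M
    const int N0 = integer_divide_ceil(N, NPerBlock); // thread blocks along N

    // The grid is launched as (x, y, z) = (N0, M0, k_split); the device-side
    // CalculateBottomIndex then reads (blockIdx.z, blockIdx.y, blockIdx.x) back
    // as the (k-split, M-tile, N-tile) work coordinate.
    const auto grid = std::make_tuple(N0, M0, k_split);
    (void)grid;
    return 0;
}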
///
/// @brief Block-to-C-tile map which fosters an external mechanism for setting up the local block id.
///
/// For example, this type can easily be used to implement a tile-looping work distribution
/// scheme.
///
/// @tparam UnderlyingBlockToCTileMap The type of the local tile map.
///
template <typename UnderlyingBlockToCTileMap>
struct LocalBlockToCTileMap
{
using underlying_type = UnderlyingBlockToCTileMap;
__host__ __device__ LocalBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
index_t local_id)
: block_to_ctile_map_{block_to_ctile_map}, local_block_id_{local_id}
{
}
__host__ __device__ constexpr auto CalculateBottomIndex() const
{
return block_to_ctile_map_.CalculateBottomIndex(make_multi_index(local_block_id_));
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
template <typename CGridDesc_M_N>
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
}
UnderlyingBlockToCTileMap block_to_ctile_map_;
index_t local_block_id_;
};
enum StreamKReductionStrategy enum StreamKReductionStrategy
{ {
Atomic = 0, // sk block use atomic to do reduction Atomic = 0, // sk block use atomic to do reduction
......
...@@ -428,7 +428,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -428,7 +428,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
[&](auto i) { [&](auto i) {
using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>; using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
return MakeAGridDescriptor_M_N<ALayout, GemmSpec>(MRaws[i], KRaws[i], AsStride[i]); return MakeAGridDescriptor_M_K<ALayout, GemmSpec>(MRaws[i], KRaws[i], AsStride[i]);
}, },
Number<NumATensor>{}); Number<NumATensor>{});
} }
...@@ -656,7 +656,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -656,7 +656,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
ComputeDataType, ComputeDataType, // ComputeDataType for A
ComputeDataType, // ComputeDataType for B
AccDataType, AccDataType,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
......
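The selector above now receives the compute type for A and for B as two separate template arguments. A toy sketch of the shape of that change, under the assumption that the motivation is mixed-input GEMMs staging the two operands through different intermediate precisions; the struct below is a stand-in, not the real BlockwiseGemmXdlops selector.

#include <cstdint>

// Stand-in with two independent compute-type slots; only the idea of separate
// per-operand compute types is taken from the change above.
template <typename ComputeTypeA, typename ComputeTypeB, typename AccType>
struct BlockwiseGemmSketch
{
    using compute_a = ComputeTypeA; // precision A fragments are converted to before the MFMA
    using compute_b = ComputeTypeB; // precision B fragments are converted to before the MFMA
    using acc       = AccType;      // accumulator precision
};

using UniformAB = BlockwiseGemmSketch<float, float, float>;       // previous single-type form
using MixedAB   = BlockwiseGemmSketch<float, std::int8_t, float>; // now expressible per operand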
...@@ -4,8 +4,6 @@ ...@@ -4,8 +4,6 @@
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <ostream>
#include <string>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
...@@ -44,20 +42,4 @@ constexpr auto GridwiseGemmPipeline_Selector() ...@@ -44,20 +42,4 @@ constexpr auto GridwiseGemmPipeline_Selector()
} }
} }
inline std::string getPipelineVersionString(const PipelineVersion& pv)
{
switch(pv)
{
case PipelineVersion::v1: return "PipelineVersion::v1";
case PipelineVersion::v2: return "PipelineVersion::v2";
default: return "Unrecognized pipeline version!";
}
}
} // namespace ck } // namespace ck
inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion pv)
{
os << ck::getPipelineVersionString(pv);
return os;
}
...@@ -27,7 +27,8 @@ __global__ void ...@@ -27,7 +27,8 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg) kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg,
const Block2CTileMap& b2c_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
...@@ -35,12 +36,11 @@ __global__ void ...@@ -35,12 +36,11 @@ __global__ void
__shared__ uint8_t p_shared[shared_size]; __shared__ uint8_t p_shared[shared_size];
Block2CTileMap b2c_map{get_block_1d_id()};
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>( GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
karg, static_cast<void*>(p_shared), b2c_map); karg, static_cast<void*>(p_shared), b2c_map);
#else #else
ignore = karg; ignore = karg;
ignore = b2c_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
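With this change the block-to-C-tile map is constructed once on the host (for example via the MakeCBlockClusterAdaptor added further down) and passed into the kernel by value, instead of each block rebuilding it from get_block_1d_id(). A rough launch-side sketch under those assumptions; this is not CK's actual launch helper, and the 256-thread block size is illustrative.

#include <hip/hip_runtime.h>

// `GemmKernel` stands in for a concrete instantiation of kernel_gemm_xdlops_v2r4r2_simplified.
template <typename GemmKernel, typename GemmArgument, typename Block2CTileMap>
void launch_gemm(GemmKernel kernel,
                 const GemmArgument& karg,
                 const Block2CTileMap& b2c_map, // built on the host, e.g. via MakeCBlockClusterAdaptor
                 unsigned grid_size,
                 hipStream_t stream)
{
    // Both the argument struct and the tile map go into the launch by value, so the
    // device side no longer reconstructs the map per block.
    kernel<<<dim3(grid_size), dim3(256), 0, stream>>>(karg, b2c_map);
}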
...@@ -541,6 +541,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -541,6 +541,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
} }
// return block_id to C matrix tile idx (m0, n0) mapping
template <typename CGridDesc>
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
{
return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc>(
c_m_n_grid_desc, 8, KBatch);
}
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{ {
...@@ -566,28 +575,18 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -566,28 +575,18 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename Block2CTileMap> typename Block2CTileMap>
__device__ static void Run(const FloatA* p_a_grid, __device__ static void Run(const Argument& karg,
const FloatB* p_b_grid,
FloatC* p_c_grid,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t MPadded,
index_t NPadded,
index_t KPadded,
index_t K0,
index_t k_batch,
void* __restrict__ p_shared_block, void* __restrict__ p_shared_block,
const Block2CTileMap& block_2_ctile_map) const Block2CTileMap& block_2_ctile_map)
{ {
const auto a_b_k0_m_k1_grid_desc = const FloatA* p_a_grid = karg.p_a_grid;
MakeAGridDescriptor_KBatch_K0_M_K1(M, MPadded, K, StrideA, k_batch, K0, KPadded); const FloatB* p_b_grid = karg.p_b_grid;
const auto b_b_k0_n_k1_grid_desc = FloatC* p_c_grid = karg.p_c_grid;
MakeBGridDescriptor_KBatch_K0_N_K1(K, NPadded, N, StrideB, k_batch, K0, KPadded); const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(M, N, StrideC); karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded);
const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
...@@ -603,7 +602,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -603,7 +602,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [KBatch, M, N] // divide block work by [KBatch, M, N]
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(); const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex( if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx, block_work_idx,
...@@ -1010,34 +1010,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -1010,34 +1010,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
} }
} }
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename Block2CTileMap>
__device__ static void Run(const Argument& karg,
void* __restrict__ p_shared_block,
const Block2CTileMap& block_2_ctile_map)
{
Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, Block2CTileMap>(karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
karg.M,
karg.N,
karg.K,
karg.StrideA,
karg.StrideB,
karg.StrideC,
karg.MPadded,
karg.NPadded,
karg.KPadded,
karg.K0,
karg.k_batch,
p_shared_block,
block_2_ctile_map);
}
static constexpr auto GetMPerBlock() { return MPerBlock; }
static constexpr auto GetNPerBlock() { return NPerBlock; }
static std::string GetTypeString() static std::string GetTypeString()
{ {
auto str = std::stringstream(); auto str = std::stringstream();
......
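A hedged illustration of the signature change above: the problem description now travels in one Argument-style struct instead of fourteen loose scalars, and the wrapper that unpacked them has been removed. The struct and helper below are simplified stand-ins, not the real GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 types.

// Simplified stand-in for the Argument struct that Run() now consumes directly.
struct GemmProblemDesc
{
    const void* p_a_grid;
    const void* p_b_grid;
    void*       p_c_grid;
    int M, N, K;
    int StrideA, StrideB, StrideC;
    int MPadded, NPadded, KPadded;
    int K0, k_batch;
};

// Call sites shrink from Run(p_a, p_b, p_c, M, N, K, StrideA, ...) to Run(desc, ...),
// and adding a new field no longer touches every signature in the call chain.
inline int total_tiles(const GemmProblemDesc& desc, int m_per_block, int n_per_block)
{
    const int m_blocks = (desc.M + m_per_block - 1) / m_per_block;
    const int n_blocks = (desc.N + n_per_block - 1) / n_per_block;
    return m_blocks * n_blocks * desc.k_batch; // matches the [KBatch, M, N] work split above
}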
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp" #include "ck/tensor/static_tensor.hpp"
#include "ck/utility/is_detected.hpp"
namespace ck { namespace ck {
...@@ -211,10 +212,44 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -211,10 +212,44 @@ struct ThreadwiseTensorSliceTransfer_v3r1
auto src_vector_container = src_vector_type{ auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)}; src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
dst_vector_type op_r_v;
constexpr auto get_elem_op_vec_len = []() {
if constexpr(is_detected<is_pack8_invocable_t, decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack8_invocable)
return math::min(8, SrcScalarPerVector);
}
if constexpr(is_detected<is_pack4_invocable_t, decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack4_invocable)
return math::min(4, SrcScalarPerVector);
}
if constexpr(is_detected<is_pack2_invocable_t, decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack2_invocable)
return math::min(2, SrcScalarPerVector);
}
return 1;
};
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
// apply the src elementwise op and convert to DstData under the hood if needed
src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
src_vector_container.template AsType<src_elem_op_vec_t>()[idx]);
});
// copy data from src_vector_container into src_thread_scratch_ // copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_(thread_scratch_id) src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<src_vector_t>( .template SetAsType<dst_vector_t>(src_data_idx_seq,
src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]); op_r_v.template AsType<dst_vector_t>()[I0]);
constexpr auto move_on_dim = [&]() constexpr constexpr auto move_on_dim = [&]() constexpr
{ {
...@@ -267,19 +302,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -267,19 +302,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
{ {
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE #if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
static_ford<SliceLengths>{}([&](auto idx) { static_ford<SliceLengths>{}([&](auto idx) {
// convert from SrcData to DstData here dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
dst_thread_scratch_(idx) =
type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
}); });
#else #else
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
// TODO make this logic more generic for more sub-dword datatype // TODO make this logic more generic for more sub-dword datatype
if constexpr(SrcVectorDim != DstVectorDim && if constexpr(SrcVectorDim != DstVectorDim &&
((is_same<half_t, remove_cvref_t<SrcData>>::value && ((is_same<half_t, remove_cvref_t<DstData>>::value &&
is_same<half_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
(is_same<int8_t, remove_cvref_t<SrcData>>::value && (is_same<int8_t, remove_cvref_t<DstData>>::value &&
is_same<int8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{ {
// each transpose does // each transpose does
...@@ -313,7 +344,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -313,7 +344,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr auto data_idx_seq = generate_sequence_v2( constexpr auto data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{}); [&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
using src_vector_t = vector_type_maker_t<SrcData, SrcScalarPerVector>; using src_vector_t = vector_type_maker_t<DstData, SrcScalarPerVector>;
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>; using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
// get DstScalarPerVector # of read-only references to src vectors from // get DstScalarPerVector # of read-only references to src vectors from
...@@ -336,17 +367,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -336,17 +367,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
Number<num_dst_vector>{}); Number<num_dst_vector>{});
// do data transpose // do data transpose
transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}( transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
src_vector_refs, dst_vector_refs); src_vector_refs, dst_vector_refs);
}); });
} }
else
static_ford<SliceLengths>{}([&](auto idx) { {
// apply the src elementwise op and convert to DstData under the hood if needed static_ford<SliceLengths>{}([&](auto idx) {
DstData dst_v; dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]); });
dst_thread_scratch_(idx) = dst_v; }
});
#endif #endif
} }
...@@ -761,11 +791,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -761,11 +791,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr, using SrcThreadScratch =
SrcData, StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
SrcScalarPerVector, DstData, // data is converted to DstData when written into SrcThreadScratch
decltype(src_thread_scratch_desc_), SrcScalarPerVector,
true>; decltype(src_thread_scratch_desc_),
true>;
using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr, using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData, DstData,
......
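The packed element-wise path introduced in the ThreadwiseTensorSliceTransfer_v3r1 changes above keys off optional is_pack{2,4,8}_invocable flags on the source element-wise op, detected via ck's is_detected, and falls back to scalar application otherwise. A self-contained sketch of that detection idiom; the detector and op types below are simplified stand-ins for ck::is_detected and CK's element-wise ops, and the trait's argument order differs from CK's.

#include <type_traits>
#include <algorithm>

// Detector for an optional static flag on the op type.
template <typename Op>
using pack2_flag_t = decltype(Op::is_pack2_invocable);

// Simplified detection idiom (ck::is_detected plays this role in the real code).
template <typename T, template <typename> class Trait, typename = void>
struct detected : std::false_type {};
template <typename T, template <typename> class Trait>
struct detected<T, Trait, std::void_t<Trait<T>>> : std::true_type {};

struct PassThrough {};                                                   // no packed support -> scalar path
struct PackedScale { static constexpr bool is_pack2_invocable = true; }; // supports pack2 calls

// Pick the widest packed call the op supports, capped by the source vector width.
template <typename Op, int SrcScalarPerVector>
constexpr int elem_op_vec_len()
{
    if constexpr (detected<Op, pack2_flag_t>::value)
    {
        if constexpr (Op::is_pack2_invocable)
            return std::min(2, SrcScalarPerVector);
    }
    return 1;
}

static_assert(elem_op_vec_len<PassThrough, 8>() == 1, "falls back to scalar application");
static_assert(elem_op_vec_len<PackedScale, 8>() == 2, "uses the pack2 overload");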