Commit 15baccf2 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents 5029a5a4 a328df25
File mode changed from 100644 to 100755
...@@ -22,6 +22,8 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16) ...@@ -22,6 +22,8 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
add_example_executable(example_gemm_xdl_fp16_v2 gemm_xdl_fp16_v2.cpp) add_example_executable(example_gemm_xdl_fp16_v2 gemm_xdl_fp16_v2.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2)
add_example_executable(example_gemm_xdl_fp16_streamk_v3 gemm_xdl_fp16_streamk_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_streamk_v3)
add_example_executable(example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp) add_example_executable(example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3)
add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp) add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp)
......
...@@ -7,3 +7,21 @@ ...@@ -7,3 +7,21 @@
#arg3: run kernel # of times (>1) #arg3: run kernel # of times (>1)
./bin/example_gemm_xdl 0 1 5 ./bin/example_gemm_xdl 0 1 5
``` ```
# Instructions for ```example_gemm_xdl_fp16_streamk_v3```
## Run ```example_gemm_xdl_fp16_streamk_v3```
```bash
arg1: verification (0=no, 1=yes)
arg2: initialization (0=no init, 1=integer value, 2=decimal value)
arg3: time kernel (0=no, 1=yes)
arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC
arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK)
arg11: Grid_size(-1 for max occupancy)
bin/example_gemm_xdl_fp16_streamk_v3 1 2 1 3840 4096 4096 4096 4096 4096 1 -1
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1}
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
problem {M:3840, N:4096, K:4096, SA:4096, SB:4096, SC:4096, MP:4032, NP:4096, KRead:4096, KP:4096, AK0:512, BK0:2048, MBlock: 18, NBlock: 16, Stream-K Selection:1, Grid size:-1}
Perf: 0.292022 ms, 441.23 TFlops, 330.348 GB/s, DeviceGemmXdlUniversal<MNPadding, RRR> BlkSize: 256, BlkTile: 224x256x64, WaveTile: 16x16, WaveMap: 7x8, VmemReadVec: 8x8, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3, BlkGemmPipelinePrefetchStages: 2
```
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -45,6 +45,19 @@ struct ProblemSizeStreamK final ...@@ -45,6 +45,19 @@ struct ProblemSizeStreamK final
ck::index_t NumSKBlocks = -1; ck::index_t NumSKBlocks = -1;
}; };
struct ProblemSizeStreamK_universal final
{
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
ck::index_t Grid_size = -1; // defaults to max occupancy
ck::index_t Streamk_sel = 1; // defaults to 1-tile SK
};
struct ProblemSizeSplitK final struct ProblemSizeSplitK final
{ {
...@@ -123,6 +136,57 @@ bool parse_cmd_args<ProblemSize>(int argc, ...@@ -123,6 +136,57 @@ bool parse_cmd_args<ProblemSize>(int argc,
return true; return true;
} }
template <>
bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
char* argv[],
ProblemSizeStreamK_universal& problem_size,
ExecutionConfig& config)
{
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else if(argc >= 10)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.M = std::stoi(argv[4]);
problem_size.N = std::stoi(argv[5]);
problem_size.K = std::stoi(argv[6]);
problem_size.StrideA = std::stoi(argv[7]);
problem_size.StrideB = std::stoi(argv[8]);
problem_size.StrideC = std::stoi(argv[9]);
if(argc >= 11)
{
problem_size.Streamk_sel = std::stoi(argv[10]);
problem_size.Grid_size = std::stoi(argv[11]);
}
}
else
{
std::cerr
<< "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
<< "arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK)"
<< "\narg11: Grid_size(-1 for max occupancy)" << std::endl;
return false;
}
return true;
}
template <> template <>
bool parse_cmd_args<ProblemSizeStreamK>(int argc, bool parse_cmd_args<ProblemSizeStreamK>(int argc,
char* argv[], char* argv[],
...@@ -165,7 +229,8 @@ bool parse_cmd_args<ProblemSizeStreamK>(int argc, ...@@ -165,7 +229,8 @@ bool parse_cmd_args<ProblemSizeStreamK>(int argc,
<< std::endl << std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
<< "arg10: NumSKBlocks(optional)" << std::endl; << "arg10: stream-k select (0: all DP, 1: 1-tile SK, 2: 2-tile SK)"
<< "\narg11: Grid_size(-1 for max occupancy)" << std::endl;
return false; return false;
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using CDataType = ck::half_t;
using ALayout = Row;
using BLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// clang-format off
using DeviceGemmV2_Streamk_Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
256,
224, 256,
64, 8, 2,
16, 16,
7, 8,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 8, 2, 0,
1, 2, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemm_example_streamk_v2.inc"
int main(int argc, char* argv[]) { return !run_gemm_universal_streamk_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 1e-1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 1.5e-1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 16.1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 8192.1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
using namespace ck::literals;
auto M = problem_size.M;
auto N = problem_size.N;
auto K = problem_size.K;
auto StrideA = problem_size.StrideA;
auto StrideB = problem_size.StrideB;
auto StrideC = problem_size.StrideC;
auto Grid_size = problem_size.Grid_size;
auto Streamk_sel = problem_size.Streamk_sel;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<std::size_t>(col);
}
else
{
return static_cast<std::size_t>(row);
}
}
else
return static_cast<std::size_t>(stride);
};
auto f_get_default_streamk_policy = [](ck::index_t streamk_sel) {
if(streamk_sel == -1)
{
return static_cast<std::size_t>(4);
}
else
return static_cast<std::size_t>(streamk_sel);
};
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
Streamk_sel = f_get_default_streamk_policy(Streamk_sel);
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
break;
case 3:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
#ifdef BUILD_INT4_EXAMPLE
DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
c_m_n_device_result.mDesc.GetElementSpaceSize());
const Tensor<KernelADataType> a_m_k_converted(a_m_k);
const Tensor<KernelBDataType> b_k_n_converted(b_k_n);
a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
#else
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
#endif
DeviceMem workspace;
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmV2_Streamk_Instance{};
auto invoker = gemm.MakeInvoker();
float ave_time = 0;
auto argument = gemm.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#else
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
M,
N,
K,
StrideA,
StrideB,
StrideC,
Streamk_sel,
Grid_size,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
bool pass = true;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1});
#ifdef BUILD_INT4_EXAMPLE
Tensor<CDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);
c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());
c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass &= ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result,
"Error: Incorrect results!",
get_rtol<CDataType>(),
get_atol<CDataType>());
#endif
}
if(config.time_kernel)
{
ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
return pass;
}
bool run_gemm_universal_streamk_example(int argc, char* argv[])
{
ProblemSizeStreamK_universal problem_size;
ExecutionConfig config;
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
}
...@@ -67,7 +67,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) ...@@ -67,7 +67,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
endforeach() endforeach()
#Do not build any WMMA examples if gfx11 targets are not on the list #Do not build any WMMA examples if gfx11 targets are not on the list
foreach(source IN LISTS FILE_NAME) foreach(source IN LISTS FILE_NAME)
if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
message("removing wmma example ${source} ") message("removing wmma example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}") list(REMOVE_ITEM FILE_NAME "${source}")
endif() endif()
...@@ -154,7 +154,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) ...@@ -154,7 +154,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
endforeach() endforeach()
#Do not build any WMMA examples if gfx11 targets are not on the list #Do not build any WMMA examples if gfx11 targets are not on the list
foreach(source IN LISTS FILE_NAME) foreach(source IN LISTS FILE_NAME)
if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
message("removing wmma example ${source} ") message("removing wmma example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}") list(REMOVE_ITEM FILE_NAME "${source}")
endif() endif()
......
...@@ -271,7 +271,9 @@ class FmhaBwdApiPool: ...@@ -271,7 +271,9 @@ class FmhaBwdApiPool:
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if' if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes)
# GEMM0: Q@K=S^T # GEMM0: Q@K=S^T
......
...@@ -278,6 +278,9 @@ class FmhaFwdApiPool: ...@@ -278,6 +278,9 @@ class FmhaFwdApiPool:
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if' if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
@dataclass @dataclass
......
...@@ -331,6 +331,9 @@ class FmhaFwdSplitKVApiPool: ...@@ -331,6 +331,9 @@ class FmhaFwdSplitKVApiPool:
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if' if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes) return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes)
@dataclass @dataclass
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemm_Streamk_V2 : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t Streamk_sel,
ck::index_t Grid_size,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -1404,4 +1404,326 @@ struct BlockToCTileMap_GemmStreamK ...@@ -1404,4 +1404,326 @@ struct BlockToCTileMap_GemmStreamK
} }
}; };
template <uint32_t MPerBlock_,
uint32_t NPerBlock_,
uint32_t KPerBlock_,
StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic,
uint32_t TileSwizzleSubM_ = 8,
index_t GroupNum = 8,
index_t M01_ = 4>
struct BlockToCTileMap_GemmStreamK_v2
{
static constexpr uint32_t min_k_iters_per_sk_block = 2;
static constexpr uint32_t MPerBlock = MPerBlock_;
static constexpr uint32_t NPerBlock = NPerBlock_;
static constexpr uint32_t KPerBlock = KPerBlock_;
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;
//--------------------------------------
// pass to device
mutable uint32_t sk_num_blocks;
uint32_t sk_num_big_blocks;
uint32_t dp_start_block_idx;
uint32_t reduction_start_block_idx;
uint32_t k_iters_per_big_block;
MDiv2 n_tiles;
MDiv k_iters_per_tile;
MDiv equiv_tiles_big; // for reduction
MDiv equiv_tiles_little; // for reduction
// prefer construct on host
__host__ __device__ BlockToCTileMap_GemmStreamK_v2(
uint32_t m, uint32_t n, uint32_t k, uint32_t grid_size = 1, uint32_t streamk_sel = 1)
{
// total output tiles
uint32_t num_tiles =
math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));
uint32_t dp_tiles, dp_num_blocks, sk_total_iters;
// default to regular DP GEMM if sk blocks == 0
if(streamk_sel == 0)
{
sk_num_blocks = 0;
dp_tiles = num_tiles;
sk_num_big_blocks = 0;
k_iters_per_big_block = 0;
dp_num_blocks = num_tiles; // all tile to be dp block
dp_start_block_idx = 0;
sk_total_iters = 0; // clear this tiles
}
// 2-tile sk + DP GEMM
else
{
// check if there's enough work for DP+ stream-k
bool bigEnough = num_tiles > grid_size;
// select between stream-k strategies
uint32_t sk_tiles = 0;
if(streamk_sel == 1) // 1 tile stream-k
{
sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles;
}
else if(streamk_sel == 2) // 2-tile stream-k
{
sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles;
}
else if(streamk_sel == 3) // 3-tile stream-k
{
sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size)
: num_tiles;
}
else if(streamk_sel == 4) // 4-tile stream-k
{
sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size)
: num_tiles;
}
sk_num_blocks = sk_tiles;
// remaining tiles are DP tiles
dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0;
sk_total_iters = k_iters_per_tile.get() * sk_tiles;
// k_iters_per_sk_block is the floor of avg each ck block loop over tiles.
// we need to decide how many iters for each sk block
// let m = k_iters_per_sk_block
// some of the sk block (little) will cover m iters, some (big) will cover m+1
// we have
// 1) l + b = sk_blocks
// 2) l * m + b * (m + 1) = sk_total_iters
// => (l + b) * m + b = sk_total_iters
// => sk_blocks * m + b = sk_total_iters
// => b = sk_total_iters - m * sk_blocks
// NOTE: big could be zero
uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
k_iters_per_big_block = k_iters_per_sk_block + 1;
dp_num_blocks = dp_tiles;
dp_start_block_idx = sk_num_blocks;
}
n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
// using multiple blocks for parallel reduction
reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
{
uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
equiv_tiles_big = MDiv(upper_big / k_iters_per_tile.get());
equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
}
}
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
return M0 * N0;
}
__host__ __device__ uint32_t get_sk_total_iters() const
{
uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block +
(sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
return sk_total_iters;
}
__host__ __device__ uint32_t get_sk_tiles() const
{
// tiles for sk
uint32_t sk_total_iters = get_sk_total_iters();
return k_iters_per_tile.div(sk_total_iters);
}
__host__ __device__ index_t get_grid_dims() const
{
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
{
// return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
return reduction_start_block_idx + get_sk_tiles();
}
else
return reduction_start_block_idx;
}
__device__ uint32_t get_block_idx() const
{
// TODO: swizzle block index for better locality
return __builtin_amdgcn_readfirstlane(blockIdx.x);
}
__device__ void
get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const
{
if(block_idx < sk_num_big_blocks)
{
iter_start = block_idx * k_iters_per_big_block;
iter_end = iter_start + k_iters_per_big_block;
}
else if(block_idx < sk_num_blocks)
{
iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
(block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
iter_end = iter_start + (k_iters_per_big_block - 1);
}
else if(block_idx >= dp_start_block_idx)
{
uint32_t sk_total_iters = get_sk_total_iters();
uint32_t dp_iters_per_block = k_iters_per_tile.get();
iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
iter_end = iter_start + dp_iters_per_block;
}
}
__device__ uint32_t get_current_iter_length(uint32_t iter_start,
uint32_t iter_end,
uint32_t total_iter_length) const
{
uint32_t iter_length_mod, iter_length_quo /*unused*/;
k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
uint32_t current_iter_length = math::min(
iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length);
return current_iter_length;
}
__device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); }
__device__ void
get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const
{
k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
}
__device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
{
uint32_t m_tile_idx, n_tile_idx;
uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);
// // swizzle tile
uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock);
uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m;
const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem))
? tile_swizzle_sub_m
: tile_swizzle_sub_m_rem;
uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;
uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;
uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;
n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
n_tile_idx_with_adapt);
}
__host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
{
static constexpr uint32_t alignment = 128;
uint32_t acc_buffer_bytes =
MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
}
__host__ __device__ uint32_t get_workspace_size_for_semaphore() const
{
return get_sk_tiles() * sizeof(uint32_t);
}
__host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
{
return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
}
__host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
const MDiv& equiv_tiles_) const
{
uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1);
uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1;
uint32_t quo_, rem_;
equiv_tiles_.divmod(tile_idx_, quo_, rem_);
return quo_ * max_equiv_tiles_ + rem_;
}
__host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
uint32_t iters_per_sk_block_) const
{
return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
1);
}
__host__ __device__ uint32_t get_total_acc_buffers() const
{
uint32_t tiles_cover_big_blocks =
get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
uint32_t tiles_cover_little_blocks =
get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);
uint32_t total_intersec_big =
get_tile_intersections(tiles_cover_big_blocks, equiv_tiles_big);
uint32_t total_intersec_little =
get_tile_intersections(tiles_cover_little_blocks, equiv_tiles_little);
return sk_num_blocks + total_intersec_big + total_intersec_little;
}
__device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
{
// TODO: from big to little
uint32_t tiles_cover_big_blocks =
get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
if(tile_idx_ < tiles_cover_big_blocks)
{
uint32_t touched_sk_blocks =
(tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
k_iters_per_big_block;
uint32_t current_intersec = get_tile_intersections(tile_idx_, equiv_tiles_big);
return touched_sk_blocks + current_intersec;
}
else
{
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
uint32_t tile_idx_little_reverse = get_sk_tiles() - tile_idx_;
uint32_t touched_sk_blocks =
(tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
iters_per_little_sk_block;
uint32_t current_intersec =
get_tile_intersections(tile_idx_little_reverse, equiv_tiles_little);
return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
}
}
__device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
{
uint32_t iters_per_big_sk_block = k_iters_per_big_block;
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
if(block_idx_ < sk_num_big_blocks)
{
uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
k_iters_per_tile.get() - 1);
uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_big);
return block_idx_ + current_intersec;
}
else
{
uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
uint32_t touched_tiles = k_iters_per_tile.div(
block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_little);
return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
}
}
};
} // namespace ck } // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/amd_smfmac.hpp"
namespace ck {
enum struct SmfmacInstr
{
smfmac_f32_16x16x32f16 = 0,
smfmac_f32_32x32x16f16,
smfmac_f32_16x16x32bf16,
smfmac_f32_32x32x16bf16,
};
template <SmfmacInstr instr>
struct smfmac_type;
template <>
struct smfmac<SmfmacInstr::smfmac_f32_16x16x32f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
{
intrin_smfmac_f32_16x16x32f16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
}
};
template <>
struct smfmac<SmfmacInstr::smfmac_f32_32x32x16f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 16;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
{
intrin_smfmac_f32_32x32x16f16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
}
};
template <>
struct smfmac<SmfmacInstr::smfmac_f32_16x16x32bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
{
intrin_smfmac_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
}
};
template <>
struct smfmac<SmfmacInstr::smfmac_f32_32x32x16bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 16;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
{
intrin_smfmac_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
}
};
template <typename base_type,
index_t MPerXdlops,
index_t NPerXdlops,
typename additional_type = base_type>
struct SmfmacSelector
{
template <typename base_type_,
index_t MPerXdlops_,
index_t NPerXdlops_,
typename additional_type_ = base_type_>
static constexpr auto GetSmfmac();
template <>
static constexpr auto GetSmfmac<half_t, 16, 16>()
{
return SmfmacInstr::smfmac_f32_16x16x32f16;
}
template <>
static constexpr auto GetSmfmac<half_t, 32, 32>()
{
return SmfmacInstr::smfmac_f32_32x32x16f16;
}
template <>
static constexpr auto GetSmfmac<bhalf_t, 16, 16>()
{
return SmfmacInstr::smfmac_f32_16x16x32bf16;
}
template <>
static constexpr auto GetSmfmac<bhalf_t, 32, 32>()
{
return SmfmacInstr::smfmac_f32_32x32x16bf16;
}
static constexpr auto selected_smfmac =
smfmac_type<GetSmfmac<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};
__host__ __device__ constexpr SmfmacSelector()
{
static_assert(selected_smfmac.group_size * selected_smfmac.num_groups_per_blk ==
selected_smfmac.num_regs_per_blk,
"wrong! num_regs_per_blk");
static_assert(selected_smfmac.num_threads_per_blk == selected_smfmac.n_per_blk,
"n_per_blk != num_threads_per_blk");
static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.num_input_blks ==
selected_smfmac.m_per_blk,
"m_per_blk != num_input_blks * num_regs_per_blk");
static_assert(selected_smfmac.num_output_blks == selected_smfmac.num_input_blks ||
selected_smfmac.num_output_blks == 1,
"incorrect num_output_blks");
static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.wave_size ==
selected_smfmac.m_per_blk * selected_smfmac.n_per_blk,
"num_regs_per_blk incorrect");
static_assert(selected_smfmac.is_k_reduction ||
(selected_smfmac.num_input_blks == selected_smfmac.num_output_blks),
"is_k_reduction wrong!");
}
static constexpr index_t GetKPerXdlops()
{
return (selected_smfmac.is_k_reduction ? selected_smfmac.num_input_blks : 1) *
selected_smfmac.k_per_blk;
}
static constexpr index_t GetK1PerXdlops() { return selected_smfmac.k_per_blk; }
};
template <typename base_type,
index_t MPerXdlops,
index_t NPerXdlops,
index_t KPack,
typename additional_type = base_type>
struct SparseXdlopsGemm
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
using CIndex = MultiIndex<2>;
using CIndex4D = MultiIndex<4>;
__device__ static constexpr index_t GetNumBlks() { return smfmac_instr.num_output_blks; }
__device__ static constexpr index_t GetNumXdlops()
{
return MPerXdlops * NPerXdlops /
(smfmac_instr.m_per_blk * smfmac_instr.n_per_blk * smfmac_instr.num_output_blks);
}
__host__ __device__ constexpr SparseXdlopsGemm()
{
static_assert(NPerXdlops == 16 || NPerXdlops == 32,
"Only support GemmNPerXdlops == 16 or 32 for smfmac xdlops");
static_assert(MPerXdlops == 16 || MPerXdlops == 32,
"Only support GemmMPerXdlops == 16 or 32 for smfmac xdlops");
static_assert(KPack % smfmac_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
}
// XDL output supporting C = A * B
// M2_N2 -> M2_M3_M4_N2
template <typename CDesc_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
{
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2,
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_unmerge_transform(make_tuple(Number<smfmac_instr.num_groups_per_blk>{},
Number<smfmac_instr.num_input_blks>{},
Number<smfmac_instr.group_size>{})),
make_pass_through_transform(Number<smfmac_instr.num_threads_per_blk>{})),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4, 5, 6>{},
Sequence<7>{}));
}
template <typename CDesc_G_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
{
const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
return transform_tensor_descriptor(
c_desc_g_m0_n0_m1_n1_m2_n2,
make_tuple(make_pass_through_transform(G),
make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_unmerge_transform(make_tuple(smfmac_instr.num_groups_per_blk,
smfmac_instr.num_input_blks,
smfmac_instr.group_size)),
make_pass_through_transform(smfmac_instr.num_threads_per_blk)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6, 7>{},
Sequence<8>{}));
}
__device__ static constexpr index_t GetRegSizePerXdlops()
{
return MPerXdlops * NPerXdlops / smfmac_instr.wave_size;
}
__device__ static constexpr index_t GetWaveSize() { return smfmac_instr.wave_size; }
template <class FloatA, class FloatB, class Idx, class FloatC>
__device__ void
Run(const FloatA& p_a_wave, const FloatB& p_b_wave, const Idx& idx, FloatC& p_c_thread) const
{
static_assert(is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value,
"base base_type must be half or bfloat16!");
static_for<0, KPack / smfmac_instr.k_per_blk, 1>{}([&](auto k) {
smfmac_instr.template run<MPerXdlops, NPerXdlops>(
p_a_wave[k], p_b_wave[k], idx[k], p_c_thread);
});
}
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % smfmac_instr.wave_size; }
__device__ static auto GetBlkIdx()
{
const auto laneId = GetLaneId();
constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(
make_tuple(1, smfmac_instr.num_input_blks, smfmac_instr.num_threads_per_blk))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto blk_idx =
threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
const auto blk_id = blk_idx[I1];
const auto blk_td = blk_idx[I2];
return make_tuple(blk_id, blk_td);
}
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
{
const auto laneId = GetLaneId();
const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
if constexpr(smfmac_instr.is_k_reduction)
{
return make_tuple(blk_id, blk_td);
}
else
{
return make_tuple(0, laneId);
}
}
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
{
const auto laneId = GetLaneId();
const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
if constexpr(smfmac_instr.is_k_reduction)
{
return make_tuple(blk_id, blk_td);
}
else
{
return make_tuple(0, laneId);
}
}
__device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
{
const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
index_t n_offset = blk_i * smfmac_instr.n_per_blk + blk_td;
index_t m_offset = xdlops_i * smfmac_instr.m_per_blk + blk_id * smfmac_instr.group_size;
return CIndex{m_offset, n_offset};
}
__device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
{
const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
return CIndex4D{I0, blk_id, I0, blk_td};
}
static constexpr auto smfmac =
SmfmacSelector<base_type, MPerXdlops, NPerXdlops, additional_type>{};
static constexpr auto smfmac_instr = smfmac.selected_smfmac;
static constexpr auto KPerXdlops = smfmac.GetKPerXdlops();
static constexpr auto K1PerXdlops = smfmac.GetK1PerXdlops();
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
__host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
{
return make_tuple(
Number<smfmac_instr.num_groups_per_blk>{}, I1, Number<smfmac_instr.group_size>{}, I1);
}
};
} // namespace ck
...@@ -82,14 +82,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load() ...@@ -82,14 +82,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load()
" ::); " ::);
} }
CK_TILE_DEVICE void s_nop() CK_TILE_DEVICE void s_nop(index_t cnt = 0)
{ {
#if 1 #if 1
asm volatile("\ asm volatile("s_nop %0" : : "n"(cnt) :);
s_nop 0 \n \
" ::);
#else #else
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(cnt);
#endif #endif
} }
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#define __gfx12__ #define __gfx12__
#endif #endif
#include "hip/hip_version.h"
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h" #include "hip/hip_runtime.h"
#include "hip/hip_fp16.h" #include "hip/hip_fp16.h"
...@@ -147,6 +148,14 @@ ...@@ -147,6 +148,14 @@
#define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 #define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif #endif
#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1
#else
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0
#endif
#endif
#ifndef CK_TILE_DEBUG_LOG #ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0 #define CK_TILE_DEBUG_LOG 0
#endif #endif
......
...@@ -69,6 +69,8 @@ struct buffer_view<address_space_enum::generic, ...@@ -69,6 +69,8 @@ struct buffer_view<address_space_enum::generic,
{ {
} }
CK_TILE_HOST_DEVICE void init_raw() {}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space() CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{ {
return address_space_enum::generic; return address_space_enum::generic;
...@@ -224,25 +226,36 @@ struct buffer_view<address_space_enum::global, ...@@ -224,25 +226,36 @@ struct buffer_view<address_space_enum::global,
T* p_data_ = nullptr; T* p_data_ = nullptr;
BufferSizeType buffer_size_; BufferSizeType buffer_size_;
int32x4_t cached_buf_res_;
remove_cvref_t<T> invalid_element_value_ = T{0}; remove_cvref_t<T> invalid_element_value_ = T{0};
CK_TILE_HOST_DEVICE constexpr buffer_view() CK_TILE_HOST_DEVICE constexpr buffer_view()
: p_data_{}, buffer_size_{}, invalid_element_value_{} : p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{}
{ {
} }
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size) CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0} : p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0}
{ {
} }
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
BufferSizeType buffer_size, BufferSizeType buffer_size,
T invalid_element_value) T invalid_element_value)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value} : p_data_{p_data},
buffer_size_{buffer_size},
cached_buf_res_{0},
invalid_element_value_{invalid_element_value}
{ {
} }
// this is non constexpr intentially (will call some intrinsic internally)
// Must call for buffers that need *_raw load/store
CK_TILE_HOST_DEVICE void init_raw()
{
cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type));
}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space() CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{ {
return address_space_enum::global; return address_space_enum::global;
...@@ -333,12 +346,15 @@ struct buffer_view<address_space_enum::global, ...@@ -333,12 +346,15 @@ struct buffer_view<address_space_enum::global,
// i is offset of T, not X. i should be aligned to X // i is offset of T, not X. i should be aligned to X
template <typename X, template <typename X,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if< typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type, std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value, typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false> bool>::type = false>
CK_TILE_DEVICE constexpr auto CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst,
get_raw(remove_cvref_t<X>& dst, index_t i, bool is_valid_element) const index_t i,
bool is_valid_element,
bool_constant<pre_nop> = {}) const
{ {
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size; constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
...@@ -349,18 +365,21 @@ struct buffer_view<address_space_enum::global, ...@@ -349,18 +365,21 @@ struct buffer_view<address_space_enum::global,
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>( amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
dst, p_data_, i, buffer_size_, is_valid_element); dst, cached_buf_res_, i, is_valid_element, bool_constant<pre_nop>{});
} }
// i is offset of T, not X. i should be aligned to X // i is offset of T, not X. i should be aligned to X
template <typename X, template <typename X,
bool pre_nop = false,
typename std::enable_if< typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type, std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value, typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false> bool>::type = false>
CK_TILE_DEVICE constexpr auto CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t<T>* smem,
async_get(remove_cvref_t<T>* smem, index_t i, bool /*is_valid_element*/) const index_t i,
bool /*is_valid_element*/,
bool_constant<pre_nop> = {}) const
{ {
// X is vector of T // X is vector of T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size; constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
...@@ -371,8 +390,8 @@ struct buffer_view<address_space_enum::global, ...@@ -371,8 +390,8 @@ struct buffer_view<address_space_enum::global,
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>( amd_async_buffer_load_with_oob_raw<remove_cvref_t<T>, t_per_x, Coherence>(
smem, p_data_, i, buffer_size_); smem, cached_buf_res_, i, bool_constant<pre_nop>{});
} }
// i is offset of T, not X. i should be aligned to X // i is offset of T, not X. i should be aligned to X
...@@ -627,6 +646,8 @@ struct buffer_view<address_space_enum::lds, ...@@ -627,6 +646,8 @@ struct buffer_view<address_space_enum::lds,
{ {
} }
CK_TILE_HOST_DEVICE void init_raw() {}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space() CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{ {
return address_space_enum::lds; return address_space_enum::lds;
...@@ -909,6 +930,8 @@ struct buffer_view<address_space_enum::vgpr, ...@@ -909,6 +930,8 @@ struct buffer_view<address_space_enum::vgpr,
{ {
} }
CK_TILE_HOST_DEVICE void init_raw() {}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space() CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{ {
return address_space_enum::vgpr; return address_space_enum::vgpr;
......
...@@ -36,30 +36,37 @@ template <typename T, ...@@ -36,30 +36,37 @@ template <typename T,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
bool oob_conditional_check = true> bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile, CK_TILE_DEVICE auto load_tile_raw(T& tile,
const tile_window_with_static_distribution<BottomTensorView_, const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
bool_constant<oob_conditional_check> = {}) bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{ {
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}); tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
} }
template <typename LdsTileWindow_, template <typename LdsTileWindow_,
typename BottomTensorView_, typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord> index_t NumCoord,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto CK_TILE_DEVICE auto
async_load_tile_raw(LdsTileWindow_&& lds_tile, async_load_tile_raw(LdsTileWindow_&& lds_tile,
const tile_window_with_static_distribution<BottomTensorView_, const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window) NumCoord>& tile_window,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{ {
return tile_window.async_load(lds_tile); return tile_window.async_load_raw(
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
} }
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0) CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment