Unverified Commit cac014f1 authored by Anthony Chang's avatar Anthony Chang Committed by GitHub
Browse files

Fused attention (#345)



* initial stub for gemm_gemm_xdl_cshuffle

* set up example code

* compiles

* prevent integer overflow

* harmonize interface between ref_gemm and ref_batched_gemm

* batched_gemm_gemm

* fix example

* host tensor gen: diagonal pattern in lowest two-dimensions only

* make c descriptors containing only integral constants

* clean up

* add BlockwiseGemmXdlops_v2 while exploring an unified approach

* implement proper interface

* tidy up example

* fix compilation warnings

* coarsely controlled 2nd gemm padding

* remove rocm-cmake's hard requirement for certain revision

* clang-format

* resolve merge conflict

* fix compilation error on gfx10

* adds acc0 elementwise op to interface

* attention host validation

* add blockwsie softmax v1

* iteratively update softmax+gemm

* transpose both gemm0 and gemm1 xdl output so as to avoid broadcasting softmax max/sum

* add init method for easier debugging

* do away with manual thread cluster calculation

* generalize blockwise softmax interface

* row-wise softmax sum & max

* format

* rename to DeviceBatchedGemmSoftmaxGemm

* add gemm_softmax_gemm instances and tests

* comment
Co-authored-by: default avatarltqin <letao.qin@amd.com>
Co-authored-by: default avatarChao Liu <chao.liu2@amd.com>
parent a670a5a0
...@@ -8,7 +8,7 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") ...@@ -8,7 +8,7 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
enable_testing() enable_testing()
set(ROCM_SYMLINK_LIBS OFF) set(ROCM_SYMLINK_LIBS OFF)
find_package(ROCM 0.8 REQUIRED PATHS /opt/rocm) find_package(ROCM REQUIRED PATHS /opt/rocm)
include(ROCMInstallTargets) include(ROCMInstallTargets)
include(ROCMPackageConfigHelpers) include(ROCMPackageConfigHelpers)
......
...@@ -66,8 +66,14 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc ...@@ -66,8 +66,14 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc
< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; < Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on // clang-format on
using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: using ReferenceBatchedGemmInstance =
ReferenceBatchedGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>; ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
BDataType,
CDataType,
ReduceAccDataType,
AElementOp,
BElementOp,
CElementOp>;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
......
...@@ -51,8 +51,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmEPermu ...@@ -51,8 +51,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmEPermu
< ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; < ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on // clang-format on
using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: using ReferenceBatchedGemmInstance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
ReferenceBatchedGemm<ADataType, BDataType, EDataType, AElementOp, BElementOp, CDEElementOp>; BDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
......
# TODO: add example batched_gemm_gemm_xdl_fp16
add_example_executable(example_batched_gemm_softmax_gemm_xdl_fp16 batched_gemm_softmax_gemm_xdl_fp16.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
|------------|
Gemm0
|---------------------|
Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using B0DataType = F16;
using B1DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using ALayout = Row;
using B0Layout = Col;
using B1Layout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ALayout,
B0Layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmDefault,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
AccDataType,
AccDataType,
AElementOp,
B0ElementOp,
CElementOp>;
// Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;
// Ref Gemm1: fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
AccDataType,
AElementOp,
B1ElementOp,
CElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 64;
ck::index_t O = 128;
ck::index_t BatchCount = 4;
ck::index_t StrideA = -1;
ck::index_t StrideB0 = -1;
ck::index_t StrideB1 = -1;
ck::index_t StrideC = -1;
ck::index_t BatchStrideA = -1;
ck::index_t BatchStrideB0 = -1;
ck::index_t BatchStrideB1 = -1;
ck::index_t BatchStrideC = -1;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 9)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
BatchCount = std::stoi(argv[8]);
}
else if(argc == 17)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
BatchCount = std::stoi(argv[8]);
StrideA = std::stoi(argv[9]);
StrideB0 = std::stoi(argv[10]);
StrideB1 = std::stoi(argv[11]);
StrideC = std::stoi(argv[12]);
BatchStrideA = std::stoi(argv[13]);
BatchStrideB0 = std::stoi(argv[14]);
BatchStrideB1 = std::stoi(argv[15]);
BatchStrideC = std::stoi(argv[16]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 17: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, "
"BatchStrideB0, BatchStrideB1, BatchStrideC\n");
exit(0);
}
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? O : M;
StrideA = (StrideA < 0) ? DefaultStrideA : StrideA;
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
StrideC = (StrideC < 0) ? DefaultStrideC : StrideC;
const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
const int DefaultBatchStrideC = (ck::is_same_v<CLayout, Col> ? O : M) * StrideC;
BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC;
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), Row>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, 1, stride}));
}
};
// C_m_o = A_m_k * B0_k_n * B1_n_o
Tensor<ADataType> a_g_m_k(
f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
Tensor<B0DataType> b0_g_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<B1DataType> b1_g_n_o(
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<CDataType> c_g_m_o_host_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
Tensor<CDataType> c_g_m_o_device_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
break;
case 2:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
break;
case 3:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
DeviceMem c_g_m_o_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSize());
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
auto acc0_element_op = Acc0ElementOp{};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument =
gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
BatchCount,
StrideA,
StrideB0,
StrideB1,
StrideC,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
BatchStrideC,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
if(do_verification)
{
// Output of Gemm0 is input A of Gemm1
Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, PassThrough{});
ref_gemm0_invoker.Run(ref_gemm0_argument);
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});
ref_softmax_invoker.Run(ref_softmax_argument);
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument);
return ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData) ? 0 : 1;
}
return 0;
}
...@@ -46,3 +46,4 @@ add_subdirectory(28_grouped_gemm_bias_e_permute) ...@@ -46,3 +46,4 @@ add_subdirectory(28_grouped_gemm_bias_e_permute)
add_subdirectory(29_batched_gemm_bias_e_permute) add_subdirectory(29_batched_gemm_bias_e_permute)
add_subdirectory(30_grouped_convnd_fwd_bias_relu) add_subdirectory(30_grouped_convnd_fwd_bias_relu)
add_subdirectory(31_grouped_convnd_fwd_bias_relu_add) add_subdirectory(31_grouped_convnd_fwd_bias_relu_add)
add_subdirectory(32_batched_gemm_gemm)
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/tensor_description/multi_index_transform.hpp" #include "ck/tensor_description/multi_index_transform.hpp"
namespace ck { namespace ck {
...@@ -159,6 +160,12 @@ struct TensorDescriptor ...@@ -159,6 +160,12 @@ struct TensorDescriptor
return transforms_[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}]; return transforms_[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
} }
__host__ __device__ constexpr auto GetLengths() const
{
// FIXME: use Tuple of reference instead
return generate_sequence_v2([&](auto I) { return GetLength(I); }, Number<ndim_visible_>{});
}
__host__ __device__ constexpr auto GetElementSize() const { return element_size_; } __host__ __device__ constexpr auto GetElementSize() const { return element_size_; }
__host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; } __host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; }
......
...@@ -25,6 +25,22 @@ constexpr LoopScheduler make_default_loop_scheduler() ...@@ -25,6 +25,22 @@ constexpr LoopScheduler make_default_loop_scheduler()
#endif // if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING #endif // if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
} }
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
__host__ __device__ static constexpr auto
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
{
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
return transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(
make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
template <index_t BlockSize, template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
...@@ -585,4 +601,361 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() ...@@ -585,4 +601,361 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
} }
}; };
// Blockwise gemm supporting
// 1. regular XDL output M2_M3_M4_M2 and transposed XDL output M2_N2_N3_N4
// 2. decoupled input tile descriptor and mma tile descriptor in order to support both vgpr and LDS
// source buffer
// 3. configurable k index starting position and step size after each FMA/XDL instruction
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool TransposeC = false,
index_t AMmaKStride =
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{}.K0PerXdlops,
index_t BMmaKStride =
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{}.K0PerXdlops>
struct BlockwiseGemmXdlops_v2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
xdlops_gemm.GetRegSizePerXdlops(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto CalculateAThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto tmp = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
const auto blk_idx =
TransposeC ? make_multi_index(tmp[I1], tmp[I0]) : make_multi_index(tmp[I0], tmp[I1]);
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
}
__host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other)
: a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin)
{
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_block_desc_g_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_G_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
{
const auto G = c_grid_desc_g_m_n.GetLength(I0);
const auto M = c_grid_desc_g_m_n.GetLength(I1);
const auto N = c_grid_desc_g_m_n.GetLength(I2);
const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_g_m_n,
make_tuple(make_pass_through_transform(G),
make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
}
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
static_for<0, KPerThread / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
protected:
// A[M0, M1, M2, KPerThread]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
// B[N0, N1, N2, KPerThread]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck } // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
namespace ck {
template <index_t BlockSize,
typename AccDataType,
typename ThreadMap_M_K, // thread_id to m_k
typename ThreadClusterDesc_M_K,
typename ThreadSliceDesc_M_K>
struct BlockwiseSoftmax
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);
using ThreadSliceDesc_M = decltype(
make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0))));
using ThreadwiseMaxReduce = ThreadwiseReduction<AccDataType,
ThreadSliceDesc_M_K,
ThreadSliceDesc_M,
reduce::Max,
false>;
using ThreadClusterLengths_M_K = decltype(ThreadClusterDesc_M_K{}.GetLengths());
using BlockwiseMaxReduce = PartitionedBlockwiseReduction_v2<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadMap_M_K,
reduce::Max,
false>;
using BlockwiseSumReduce = PartitionedBlockwiseReduction_v2<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadMap_M_K,
reduce::Add,
false>;
using ThreadwiseSumReduce = ThreadwiseReduction<AccDataType,
ThreadSliceDesc_M_K,
ThreadSliceDesc_M,
reduce::Add,
false>;
using BufferType = StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MRepeat, true>;
template <typename CThreadBuffer, typename WorkspaceBuffer>
__host__ __device__ void Run(CThreadBuffer& in_thread_buf, WorkspaceBuffer& reduce_work_buf)
{
// find max value
static_for<0, MRepeat, 1>{}([&](auto I) {
max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
});
ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);
static_for<0, MRepeat, 1>{}([&](auto I) {
BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I));
block_sync_lds();
});
// calculate exp for elements, P=exp(s-max)
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = math::exp(in_thread_buf[offset] - max_value_buf(iM));
});
});
// sum data
static_for<0, MRepeat, 1>{}([&](auto I) {
sum_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
});
ThreadwiseSumReduce::Reduce(in_thread_buf, sum_value_buf);
static_for<0, MRepeat, 1>{}([&](auto I) {
BlockwiseSumReduce::Reduce(reduce_work_buf, sum_value_buf(I));
block_sync_lds();
});
}
BufferType max_value_buf;
BufferType sum_value_buf;
};
} // namespace ck
...@@ -82,6 +82,78 @@ struct PartitionedBlockwiseReduction ...@@ -82,6 +82,78 @@ struct PartitionedBlockwiseReduction
}; };
}; };
// clang-format off
// Assume:
// 1) work_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data
// 2) work_buffer has AccDataType elements, and space size is no less than BlockSize
// 3) in_out_value is the input data in vgpr from each thread
// 4) in_out_value is the over-written reduced output in vgpr for each thread
// clang-format on
template <typename AccDataType,
index_t BlockSize,
typename ThreadClusterLengths_M_K,
typename ThreadClusterDesc,
typename OpReduce,
bool PropagateNan,
typename Accumulation =
detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
struct PartitionedBlockwiseReduction_v2
{
static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
"The product of cluster lengths should be same as BlockSize!");
static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);
static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements");
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));
static constexpr auto thread_cluster_desc = ThreadClusterDesc{};
template <typename BufferType>
__device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
{
static_assert(is_same<typename BufferType::type, AccDataType>{},
"Buffer data type should be consistent as AccDataType!");
constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;
__syncthreads();
static_for<0, cluster_len_shift, 1>{}([&](auto I) {
constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());
if(thread_k_cluster_id < indOffset)
{
index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
make_tuple(0, indOffset));
AccDataType opData1 = work_buffer[offset1];
AccDataType opData2 = work_buffer[offset2];
Accumulation::Calculate(opData1, opData2);
work_buffer(offset1) = opData1;
}
__syncthreads();
});
index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
in_out_value = work_buffer[offset];
};
};
// clang-format off // clang-format off
// Assume: // Assume:
// 1) work_val_buffer/work_idx_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data // 1) work_val_buffer/work_idx_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename B0Layout,
typename B1Layout,
typename CLayout,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation>
struct DeviceBatchedGemmGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b0,
const void* p_b1,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t O,
ck::index_t Batch,
ck::index_t StrideA,
ck::index_t StrideB0,
ck::index_t StrideB1,
ck::index_t StrideC,
ck::index_t BatchStrideA,
ck::index_t BatchStrideB0,
ck::index_t BatchStrideB1,
ck::index_t BatchStrideC,
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename ALayout,
typename B0Layout,
typename B1Layout,
typename CLayout,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation>
using DeviceBatchedGemmGemmPtr = std::unique_ptr<DeviceBatchedGemmGemm<ALayout,
B0Layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
AElementwiseOperation,
B0ElementwiseOperation,
Acc0ElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename B0Layout,
typename B1Layout,
typename CLayout,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation>
struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b0,
const void* p_b1,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t O,
ck::index_t Batch,
ck::index_t StrideA,
ck::index_t StrideB0,
ck::index_t StrideB1,
ck::index_t StrideC,
ck::index_t BatchStrideA,
ck::index_t BatchStrideB0,
ck::index_t BatchStrideB1,
ck::index_t BatchStrideC,
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename ALayout,
typename B0Layout,
typename B1Layout,
typename CLayout,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation>
using DeviceBatchedGemmSoftmaxGemmPtr =
std::unique_ptr<DeviceBatchedGemmSoftmaxGemm<ALayout,
B0Layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
AElementwiseOperation,
B0ElementwiseOperation,
Acc0ElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -1145,9 +1145,22 @@ struct ThreadwiseTensorSliceTransfer_v4 ...@@ -1145,9 +1145,22 @@ struct ThreadwiseTensorSliceTransfer_v4
src_desc, src_data_coord); src_desc, src_data_coord);
// copy data from src_buf into src_tmp_vector // copy data from src_buf into src_tmp_vector
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) = if constexpr(SrcBuffer::IsDynamicBuffer())
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid); {
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
}
else if constexpr(SrcBuffer::IsStaticBuffer())
{
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
i * src_scalar_step_in_vector);
// apply type convert
src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
});
}
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData) // DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector; vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
...@@ -1184,4 +1197,101 @@ struct ThreadwiseTensorSliceTransfer_v4 ...@@ -1184,4 +1197,101 @@ struct ThreadwiseTensorSliceTransfer_v4
SrcCoord src_ref_coord_; SrcCoord src_ref_coord_;
}; };
// Do NOT involve any tensor coordinates with StaticBuffer
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
typename SliceLengths,
typename DimAccessOrder,
index_t DstVectorDim,
index_t DstScalarPerVector,
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseTensorSliceTransfer_StaticToStatic
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic(
const ElementwiseOperation& element_op)
: element_op_{element_op}
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc need to known at compile-time");
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
"wrong! Not divisible");
}
template <typename SrcSliceOriginIdx,
typename DstSliceOriginIdx,
typename SrcBuffer,
typename DstBuffer>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const SrcBuffer& src_buf,
const DstDesc&,
const DstSliceOriginIdx&,
DstBuffer& dst_buf)
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc need to known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
"wrong! SliceOrigin need to known at compile-time");
static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
"wrong! Buffer need to be StaticBuffer");
// SrcDesc and src_slice_origin_idx are known at compile-time
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});
// scalar per access on each dim
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector =
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>>;
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
static_for<0, num_access, 1>{}([&](auto idx_1d) {
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
// copy data from src_buf into dst_vector
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
SrcData v;
// apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]);
// apply type convert
dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
});
});
}
ElementwiseOperation element_op_;
};
} // namespace ck } // namespace ck
...@@ -579,7 +579,11 @@ struct MfmaSelector ...@@ -579,7 +579,11 @@ struct MfmaSelector
static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; } static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; }
}; };
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack> template <typename base_type,
index_t MPerXdlops,
index_t NPerXdlops,
index_t KPack,
bool TransposeC = false>
struct XdlopsGemm struct XdlopsGemm
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -612,6 +616,8 @@ struct XdlopsGemm ...@@ -612,6 +616,8 @@ struct XdlopsGemm
static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk"); static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
} }
// XDL output supporting C = A * B
// M2_N2 -> M2_M3_M4_N2
template <typename CDesc_M0_N0_M1_N1_M2_N2> template <typename CDesc_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
...@@ -627,10 +633,10 @@ struct XdlopsGemm ...@@ -627,10 +633,10 @@ struct XdlopsGemm
make_pass_through_transform(N0), make_pass_through_transform(N0),
make_pass_through_transform(M1), make_pass_through_transform(M1),
make_pass_through_transform(N1), make_pass_through_transform(N1),
make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk, make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
mfma_instr.num_input_blks, Number<mfma_instr.num_input_blks>{},
mfma_instr.group_size)), Number<mfma_instr.group_size>{})),
make_pass_through_transform(mfma_instr.num_threads_per_blk)), make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -645,6 +651,41 @@ struct XdlopsGemm ...@@ -645,6 +651,41 @@ struct XdlopsGemm
Sequence<7>{})); Sequence<7>{}));
} }
// transposed XDL output supporting C' = B' * A'
// M2_N2 -> M2_N2_N3_N4
template <typename CDesc_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto
MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
{
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2,
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{}),
make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
Number<mfma_instr.num_input_blks>{},
Number<mfma_instr.group_size>{}))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6, 7>{}));
}
template <typename CDesc_G_M0_N0_M1_N1_M2_N2> template <typename CDesc_G_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2) const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
...@@ -698,7 +739,16 @@ struct XdlopsGemm ...@@ -698,7 +739,16 @@ struct XdlopsGemm
"base base_type must be double, float, half, bfloat16, and int8_t!"); "base base_type must be double, float, half, bfloat16, and int8_t!");
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread); if constexpr(!TransposeC)
{
mfma_instr.template run<MPerXdlops, NPerXdlops>(
p_a_wave[k], p_b_wave[k], p_c_thread);
}
else
{
mfma_instr.template run<MPerXdlops, NPerXdlops>(
p_b_wave[k], p_a_wave[k], p_c_thread);
}
}); });
} }
......
...@@ -20,6 +20,29 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N> ...@@ -20,6 +20,29 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
__host__ __device__ constexpr StaticBuffer() : base{} {} __host__ __device__ constexpr StaticBuffer() : base{} {}
__host__ __device__ constexpr StaticBuffer& operator=(StaticBuffer& y)
{
StaticBuffer& x = *this;
static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y[i]; });
return x;
}
template <typename... Ys>
__host__ __device__ constexpr StaticBuffer& operator=(const Tuple<Ys...>& y)
{
static_assert(base::Size() == sizeof...(Ys), "wrong! size not the same");
StaticBuffer& x = *this;
static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y[i]; });
return x;
}
__host__ __device__ constexpr StaticBuffer& operator=(const T& y)
{
StaticBuffer& x = *this;
static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y; });
return x;
}
__host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; } __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; }
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; } __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
...@@ -40,10 +63,12 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N> ...@@ -40,10 +63,12 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
return base::operator()(i); return base::operator()(i);
} }
__host__ __device__ void Clear() __host__ __device__ void Set(T x)
{ {
static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{0}; }); static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{x}; });
} }
__host__ __device__ void Clear() { Set(T{0}); }
}; };
// static buffer for vector // static buffer for vector
...@@ -61,6 +86,7 @@ struct StaticBufferTupleOfVector ...@@ -61,6 +86,7 @@ struct StaticBufferTupleOfVector
static constexpr auto s_per_v = Number<ScalarPerVector>{}; static constexpr auto s_per_v = Number<ScalarPerVector>{};
static constexpr auto num_of_v_ = Number<NumOfVector>{}; static constexpr auto num_of_v_ = Number<NumOfVector>{};
static constexpr auto s_per_buf = s_per_v * num_of_v_;
__host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {} __host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {}
...@@ -70,6 +96,8 @@ struct StaticBufferTupleOfVector ...@@ -70,6 +96,8 @@ struct StaticBufferTupleOfVector
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
__host__ __device__ static constexpr index_t Size() { return s_per_buf; };
// Get S // Get S
// i is offset of S // i is offset of S
template <index_t I> template <index_t I>
......
...@@ -34,7 +34,10 @@ __host__ __device__ constexpr auto to_multi_index(const T& x) ...@@ -34,7 +34,10 @@ __host__ __device__ constexpr auto to_multi_index(const T& x)
// is the alias of the latter. This is because compiler cannot infer the NSize if // is the alias of the latter. This is because compiler cannot infer the NSize if
// using MultiIndex<NSize> // using MultiIndex<NSize>
// TODO: how to fix this? // TODO: how to fix this?
template <typename... Ys, typename X> template <
typename... Ys,
typename X,
enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> = false>
__host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x) __host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
{ {
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same"); static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
...@@ -43,7 +46,10 @@ __host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x) ...@@ -43,7 +46,10 @@ __host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
return y; return y;
} }
template <typename... Ys, typename X> template <
typename... Ys,
typename X,
enable_if_t<!std::is_integral<X>::value && !std::is_floating_point<X>::value, bool> = false>
__host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x) __host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
{ {
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same"); static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
...@@ -52,7 +58,10 @@ __host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x) ...@@ -52,7 +58,10 @@ __host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
return y; return y;
} }
template <typename... Xs, typename Y> template <
typename... Xs,
typename Y,
enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y) __host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
{ {
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
...@@ -63,7 +72,10 @@ __host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y) ...@@ -63,7 +72,10 @@ __host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
return r; return r;
} }
template <typename... Xs, typename Y> template <
typename... Xs,
typename Y,
enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y) __host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
{ {
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
...@@ -74,7 +86,10 @@ __host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y) ...@@ -74,7 +86,10 @@ __host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
return r; return r;
} }
template <typename... Xs, typename Y> template <
typename... Xs,
typename Y,
enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y) __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
{ {
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
...@@ -85,9 +100,11 @@ __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y) ...@@ -85,9 +100,11 @@ __host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
return r; return r;
} }
// MultiIndex = index_t * MultiIndex // MultiIndex = scalar * MultiIndex
template <typename... Xs> template <typename... Xs,
__host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x) typename Y,
enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator*(Y a, const Tuple<Xs...>& x)
{ {
constexpr index_t NSize = sizeof...(Xs); constexpr index_t NSize = sizeof...(Xs);
...@@ -96,13 +113,40 @@ __host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x) ...@@ -96,13 +113,40 @@ __host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x)
return r; return r;
} }
// MultiIndex = MultiIndex * index_t // MultiIndex = MultiIndex * scalar
template <typename... Xs> template <typename... Xs,
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, index_t a) typename Y,
enable_if_t<std::is_integral<Y>::value || std::is_floating_point<Y>::value, bool> = false>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, Y a)
{ {
return a * x; return a * x;
} }
namespace mathext {
template <typename... Xs>
__host__ __device__ constexpr auto exp(const Tuple<Xs...>& x)
{
constexpr index_t NSize = sizeof...(Xs);
Tuple<Xs...> r;
static_for<0, NSize, 1>{}([&](auto i) { r(i) = math::exp(x[i]); });
return r;
}
template <typename... Xs, typename Y>
__host__ __device__ constexpr auto max(const Tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
Tuple<Xs...> r;
static_for<0, NSize, 1>{}([&](auto i) { r(i) = math::max(x[i], y[i]); });
return r;
}
} // namespace mathext
template <typename... Xs> template <typename... Xs>
__host__ __device__ void print_multi_index(const Tuple<Xs...>& x) __host__ __device__ void print_multi_index(const Tuple<Xs...>& x)
{ {
......
...@@ -16,6 +16,7 @@ namespace host { ...@@ -16,6 +16,7 @@ namespace host {
template <typename ADataType, template <typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename CDataType,
typename AccDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
...@@ -58,7 +59,7 @@ struct ReferenceBatchedGemm : public device::BaseOperator ...@@ -58,7 +59,7 @@ struct ReferenceBatchedGemm : public device::BaseOperator
auto f_gmk_gkn_gmn = [&](auto g, auto m, auto n) { auto f_gmk_gkn_gmn = [&](auto g, auto m, auto n) {
const int K = arg.a_g_m_k_.mDesc.GetLengths()[2]; const int K = arg.a_g_m_k_.mDesc.GetLengths()[2];
float v_acc = 0; AccDataType v_acc = 0;
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
{ {
...@@ -68,10 +69,11 @@ struct ReferenceBatchedGemm : public device::BaseOperator ...@@ -68,10 +69,11 @@ struct ReferenceBatchedGemm : public device::BaseOperator
arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k)); arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n)); arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
v_acc += ck::type_convert<float>(v_a) * ck::type_convert<float>(v_b); v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
} }
float v_c; AccDataType v_c;
arg.c_element_op_(v_c, v_acc); arg.c_element_op_(v_c, v_acc);
...@@ -81,8 +83,7 @@ struct ReferenceBatchedGemm : public device::BaseOperator ...@@ -81,8 +83,7 @@ struct ReferenceBatchedGemm : public device::BaseOperator
make_ParallelTensorFunctor(f_gmk_gkn_gmn, make_ParallelTensorFunctor(f_gmk_gkn_gmn,
arg.c_g_m_n_.mDesc.GetLengths()[0], arg.c_g_m_n_.mDesc.GetLengths()[0],
arg.c_g_m_n_.mDesc.GetLengths()[1], arg.c_g_m_n_.mDesc.GetLengths()[1],
arg.c_g_m_n_.mDesc.GetLengths()[2])( arg.c_g_m_n_.mDesc.GetLengths()[2])();
std::thread::hardware_concurrency());
return 0; return 0;
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemm<Row,
Col,
Row,
Row,
F16,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <typename ALayout,
typename B0Layout,
typename B1Layout,
typename CLayout,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm<ALayout,
B0Layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
PassThrough,
PassThrough,
PassThrough,
PassThrough,
PassThrough>>
{
using DeviceOp = DeviceBatchedGemmSoftmaxGemm<ALayout,
B0Layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
PassThrough,
PassThrough,
PassThrough,
PassThrough,
PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>)
{
add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment