Commit b010b095 authored by aska-0096

part2 of previous commit: switch the attention examples to the WMMA runner includes, pass descriptor lengths/strides as std::array inside DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle, and move the AEnableLds/BEnableLds flags to the end of the GridwiseGemmPipeline_Selector template parameter list.

parent 43777959
@@ -119,9 +119,9 @@ bool pool_test(bool do_verification,
{N, C, Hi, Wi},
{Y, X},
{N, C, Ho, Wo},
-{C * Hi * Wi, 1, Wi * C, C},
-{C * Ho * Wo, 1, Wo * C, C},
-{C * Ho * Wo, 1, Wo * C, C},
+{},
+{},
+{},
window_strides,
input_left_pads,
input_right_pads,
......
@@ -161,6 +161,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
B1ElementOp,
CElementOp>;
#include "run_batched_gemm_scale_softmax_gemm_permute.inc"
#include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
@@ -283,6 +283,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
B1ElementOp,
CElementOp>;
#include "run_batched_gemm_scale_softmax_gemm_permute.inc"
#include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
@@ -327,6 +327,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
B1ElementOp,
CElementOp>;
#include "run_cross_attention.inc"
#include "run_cross_attention_wmma.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
int run(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 120;
ck::index_t N = 1000;
ck::index_t K = 64;
ck::index_t O = 128;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t G0 = 7;
ck::index_t G1 = 13;
float alpha = 1;
bool input_permute = false;
bool output_permute = true;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 13)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]);
input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n");
printf("arg10: scale (alpha)\n");
printf("arg11 to 12: input / output permute\n");
exit(0);
}
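// example invocation (argc == 13), spelling out the defaults above:
// <executable> 1 1 1 120 1000 64 128 7 13 1.0 0 1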
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
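// Stride sanity check: with input_permute, A is stored as [G0, M, G1, K], so element
// (g0, g1, m, k) lives at offset g0*(M*G1*K) + m*(G1*K) + g1*K + k; each stride vector
// above lists exactly those per-dimension steps in the [G0, G1, M, K] lengths order.
// B0, B1, and C follow the same pattern for their respective layouts.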
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl;
std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl;
std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
break;
case 3:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
break;
case 4: // unit: a b0 b1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 5: // Rand: b1 b0; unit: a
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 6: // Rand: a b0; unit: b1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 7: // Rand: a b1 ; unit: b0
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 8: // Rand: a ; unit: b0 b1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 9: // Rand: b0 ; unit: a b1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 10: // Rand: b1 ; unit: a b0
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) *
c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
auto acc0_element_op = Acc0ElementOp{alpha};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
float best_perf = 0.f;
float best_time = 0.f;
int not_pass = 0;
std::string best_kernel = "";
printf("Verification: %s\n", do_verification ? "ON" : "OFF");
// TODO ANT: replace array with vector?
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
auto gemm = DeviceMHAInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
G0,
G1,
alpha,
input_permute,
output_permute);
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return; // skip this instance rather than running an unsupported argument
}
ck::index_t BatchCount = G0 * G1;
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
BatchCount;
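// flop counts 2*M*N*K for gemm0 plus 2*M*N*O for gemm1, per batch; dividing flop/1e9
// (GFLOP) by the kernel time in ms yields TFLOP/s, and bytes/1e6 per ms yields GB/s.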
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
if(tflops > best_perf)
{
best_perf = tflops;
best_time = ave_time * 1000;
best_kernel = gemm.GetTypeString();
}
if(do_verification)
{
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
Tensor<ADataType> a_g_m_k({BatchCount, M, K});
Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
Tensor<Acc0DataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
// permute
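// collapse [G0, G1] into a single batch index g = g0 * G1 + g1; B0 and B1 are
// transposed here so the reference GEMMs consume B0 as [K, N] and B1 as [N, O].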
a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
// gemm 0
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
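// masked positions are filled with -inf so the softmax that follows assigns them zero weight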
const auto mask = typename DeviceMHAInstance::C0MatrixMask(N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
});
// softmax
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
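// MakeArgument(in, out, 1, 0, {2}): the trailing {2} reduces over dimension 2 (the N
// axis); the 1 and 0 are, per the reference softmax interface, its alpha/beta factors.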
auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});
ref_softmax_invoker.Run(ref_softmax_argument);
// gemm1
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n,
b1_g_n_o,
c_g_m_o_host_result,
PassThrough{},
b1_element_op,
c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument);
// permute
c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
});
// default absolute and relative error tolerances are 0.001
double rtol = 1e-3;
double atol = 1e-3;
// when all operand and output types are BF16, relax both tolerances to 0.01
if(std::is_same_v<ADataType, ck::bhalf_t> && std::is_same_v<B0DataType, ck::bhalf_t> &&
std::is_same_v<B1DataType, ck::bhalf_t> && std::is_same_v<CDataType, ck::bhalf_t>)
{
rtol = 1e-2;
atol = 1e-2;
}
bool pass = ck::utils::check_err(c_gs_ms_os_device_result.mData,
c_gs_ms_os_host_result.mData,
"Error: Incorrect results!",
rtol,
atol);
printf("Verification: ON, Pass: %s\n", pass ? "YES" : "NO");
if(!pass)
{
not_pass = 1;
printf("Verification failed for MHA instance %d\n", i.value);
}
}
});
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M
<< ", N: " << N << ", K: " << K << ", O: " << O << std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time
<< " us" << std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
return not_pass;
}
@@ -283,6 +283,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
B1ElementOp,
CElementOp>;
#include "run_self_attention.inc"
#include "run_self_attention_wmma.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
@@ -252,16 +252,16 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
B1Spec,
CSpec>;
-static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
-const std::vector<index_t>& a_gs_ms_ks_strides_vec)
+static auto MakeAGridDescriptor_AK0_M_AK1(
+const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
+const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
Number<AK1>{});
}
-static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
-const std::vector<index_t>& b_gs_ns_ks_strides_vec)
+static auto MakeBGridDescriptor_BK0_N_BK1(
+const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_lengths_vec,
+const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_strides_vec)
{
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
@@ -269,8 +269,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
}
static auto
-MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
-const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
+MakeB1GridDescriptor_BK0_N_BK1(
+const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
+const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_strides_vec)
{
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
@@ -453,14 +453,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CDataType* p_c_grid,
const std::array<void*, NumD0Tensor> p_acc0_biases,
const std::array<void*, NumD1Tensor> p_acc1_biases,
-const std::vector<index_t>& a_gs_ms_ks_lengths,
-const std::vector<index_t>& a_gs_ms_ks_strides,
-const std::vector<index_t>& b_gs_ns_ks_lengths,
-const std::vector<index_t>& b_gs_ns_ks_strides,
-const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
-const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
-const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
-const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
+const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths,
+const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
+const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_lengths,
+const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_strides,
+const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
+const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
+const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
+const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumD1Tensor>&
@@ -835,20 +835,48 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
B1ElementwiseOperation b1_element_op,
C1DEElementwiseOperation c1de_element_op)
{
+constexpr auto dimension = NumDimG + NumDimM + NumDimN;
+std::array<index_t, dimension> a_gs_ms_ks_lengths_{};
+std::array<index_t, dimension> a_gs_ms_ks_strides_{};
+std::array<index_t, dimension> b_gs_ns_ks_lengths_{};
+std::array<index_t, dimension> b_gs_ns_ks_strides_{};
+std::array<index_t, dimension> b1_gs_gemm1ns_gemm1ks_lengths_{}; // b1_gs_os_ns_lengths
+std::array<index_t, dimension> b1_gs_gemm1ns_gemm1ks_strides_{}; // b1_gs_os_ns_strides
+std::array<index_t, dimension> c_gs_ms_gemm1ns_lengths_{}; // c_gs_ms_os_lengths
+std::array<index_t, dimension> c_gs_ms_gemm1ns_strides_{}; // c_gs_ms_os_strides
+std::copy(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.begin() + dimension, a_gs_ms_ks_lengths_.begin());
+std::copy(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.begin() + dimension, a_gs_ms_ks_strides_.begin());
+std::copy(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.begin() + dimension, b_gs_ns_ks_lengths_.begin());
+std::copy(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.begin() + dimension, b_gs_ns_ks_strides_.begin());
+std::copy(b1_gs_gemm1ns_gemm1ks_lengths.begin(),
+b1_gs_gemm1ns_gemm1ks_lengths.begin() + dimension,
+b1_gs_gemm1ns_gemm1ks_lengths_.begin()); // b1_gs_os_ns_lengths
+std::copy(b1_gs_gemm1ns_gemm1ks_strides.begin(),
+b1_gs_gemm1ns_gemm1ks_strides.begin() + dimension,
+b1_gs_gemm1ns_gemm1ks_strides_.begin()); // b1_gs_os_ns_strides
+std::copy(c_gs_ms_gemm1ns_lengths.begin(),
+c_gs_ms_gemm1ns_lengths.begin() + dimension,
+c_gs_ms_gemm1ns_lengths_.begin()); // c_gs_ms_os_lengths
+std::copy(c_gs_ms_gemm1ns_strides.begin(),
+c_gs_ms_gemm1ns_strides.begin() + dimension,
+c_gs_ms_gemm1ns_strides_.begin()); // c_gs_ms_os_strides
return Argument{p_a,
p_b,
p_b1,
p_c,
p_acc0_biases,
p_acc1_biases,
-a_gs_ms_ks_lengths,
-a_gs_ms_ks_strides,
-b_gs_ns_ks_lengths,
-b_gs_ns_ks_strides,
-b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
-b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
-c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
-c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
+a_gs_ms_ks_lengths_,
+a_gs_ms_ks_strides_,
+b_gs_ns_ks_lengths_,
+b_gs_ns_ks_strides_,
+b1_gs_gemm1ns_gemm1ks_lengths_, // b1_gs_os_ns_lengths
+b1_gs_gemm1ns_gemm1ks_strides_, // b1_gs_os_ns_strides
+c_gs_ms_gemm1ns_lengths_, // c_gs_ms_os_lengths
+c_gs_ms_gemm1ns_strides_, // c_gs_ms_os_strides
acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
@@ -891,20 +919,48 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
B1ElementwiseOperation b1_element_op,
C1DEElementwiseOperation c1de_element_op) override
{
+constexpr auto dimension = NumDimG + NumDimM + NumDimN;
+std::array<index_t, dimension> a_gs_ms_ks_lengths_{};
+std::array<index_t, dimension> a_gs_ms_ks_strides_{};
+std::array<index_t, dimension> b_gs_ns_ks_lengths_{};
+std::array<index_t, dimension> b_gs_ns_ks_strides_{};
+std::array<index_t, dimension> b1_gs_gemm1ns_gemm1ks_lengths_{}; // b1_gs_os_ns_lengths
+std::array<index_t, dimension> b1_gs_gemm1ns_gemm1ks_strides_{}; // b1_gs_os_ns_strides
+std::array<index_t, dimension> c_gs_ms_gemm1ns_lengths_{}; // c_gs_ms_os_lengths
+std::array<index_t, dimension> c_gs_ms_gemm1ns_strides_{}; // c_gs_ms_os_strides
+std::copy(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.begin() + dimension, a_gs_ms_ks_lengths_.begin());
+std::copy(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.begin() + dimension, a_gs_ms_ks_strides_.begin());
+std::copy(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.begin() + dimension, b_gs_ns_ks_lengths_.begin());
+std::copy(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.begin() + dimension, b_gs_ns_ks_strides_.begin());
+std::copy(b1_gs_gemm1ns_gemm1ks_lengths.begin(),
+b1_gs_gemm1ns_gemm1ks_lengths.begin() + dimension,
+b1_gs_gemm1ns_gemm1ks_lengths_.begin()); // b1_gs_os_ns_lengths
+std::copy(b1_gs_gemm1ns_gemm1ks_strides.begin(),
+b1_gs_gemm1ns_gemm1ks_strides.begin() + dimension,
+b1_gs_gemm1ns_gemm1ks_strides_.begin()); // b1_gs_os_ns_strides
+std::copy(c_gs_ms_gemm1ns_lengths.begin(),
+c_gs_ms_gemm1ns_lengths.begin() + dimension,
+c_gs_ms_gemm1ns_lengths_.begin()); // c_gs_ms_os_lengths
+std::copy(c_gs_ms_gemm1ns_strides.begin(),
+c_gs_ms_gemm1ns_strides.begin() + dimension,
+c_gs_ms_gemm1ns_strides_.begin()); // c_gs_ms_os_strides
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<const B1DataType*>(p_b1),
static_cast<CDataType*>(p_c),
p_acc0_biases, // cast in struct Argument
p_acc1_biases, // cast in struct Argument
-a_gs_ms_ks_lengths,
-a_gs_ms_ks_strides,
-b_gs_ns_ks_lengths,
-b_gs_ns_ks_strides,
-b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
-b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
-c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
-c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
+a_gs_ms_ks_lengths_,
+a_gs_ms_ks_strides_,
+b_gs_ns_ks_lengths_,
+b_gs_ns_ks_strides_,
+b1_gs_gemm1ns_gemm1ks_lengths_, // b1_gs_os_ns_lengths
+b1_gs_gemm1ns_gemm1ks_strides_, // b1_gs_os_ns_strides
+c_gs_ms_gemm1ns_lengths_, // c_gs_ms_os_lengths
+c_gs_ms_gemm1ns_strides_, // c_gs_ms_os_strides
acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths,
......
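Both MakeArgument and MakeArgumentPointer now normalize their descriptor inputs the same way: the first NumDimG + NumDimM + NumDimN entries of each runtime-sized lengths/strides container are copied into a fixed-size std::array before the Argument is constructed, so the std::vector-taking entry points are preserved. A minimal self-contained sketch of that trim pattern (illustrative only, not code from the commit; int stands in for ck::index_t):

#include <algorithm>
#include <array>
#include <vector>

// Trim a runtime-sized descriptor to the fixed rank the Argument expects.
// Rank corresponds to NumDimG + NumDimM + NumDimN in the device op; the
// caller must supply at least Rank entries.
template <std::size_t Rank>
std::array<int, Rank> to_fixed_descriptor(const std::vector<int>& v)
{
    std::array<int, Rank> out{};
    std::copy(v.begin(), v.begin() + Rank, out.begin());
    return out;
}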
@@ -119,10 +119,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
using GridwiseGemmPipe =
remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVer,
-AEnableLds,
-B0EnableLds,
NumGemmKPrefetchStage,
-LoopSched>())>;
+LoopSched,
+AEnableLds,
+B0EnableLds>())>;
__host__ __device__ static constexpr auto MakeABlockDescriptor()
{
......
@@ -15,10 +15,10 @@ enum struct PipelineVersion
};
template <PipelineVersion PipelineVer,
-bool AEnableLds = true,
-bool BEnableLds = true,
index_t NumPrefetch = 1,
-LoopScheduler LoopSched = LoopScheduler::Default>
+LoopScheduler LoopSched = LoopScheduler::Default,
+bool AEnableLds = true,
+bool BEnableLds = true>
constexpr auto GridwiseGemmPipeline_Selector()
{
if constexpr(PipelineVer == PipelineVersion::v1)
......
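After this reordering, positional instantiations must list NumPrefetch and LoopSched before the LDS flags. A minimal sketch of the new parameter order, mirroring the gridwise usage above (the concrete values are illustrative defaults, not taken from the commit):

using Pipeline = remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVersion::v1,
                                                                       1,                      // NumPrefetch
                                                                       LoopScheduler::Default, // LoopSched
                                                                       true,                   // AEnableLds
                                                                       true>())>;              // BEnableLds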