Commit 27482328 authored by fsx950223

format code

parent 24f96b50
@@ -152,7 +152,7 @@ using DeviceGemmInstance =
         8, // CShuffleBlockTransferScalarPerVector_NPerBlock
         MaskingSpec>; // MaskingSpecialization
#else
using DeviceGemmInstance =
    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
        NumDimG,
        NumDimM,
@@ -313,7 +313,7 @@ int run(int argc, char* argv[])
    // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
    // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
    // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
    float K = 128;
    float alpha = 1.f / std::sqrt(K);
    bool input_permute = false;
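Note, as context for the hunk above rather than part of the commit: the comment describes the forward pass whose gradients this backward example verifies. Below is a minimal single-batch sketch of that computation using plain row-major buffers and hypothetical names (not the library's Tensor or host-reference types); the two reshape/permute comment lines only describe how the result is later stored as [G0, M, G1, O] versus [G0, G1, M, O] depending on the permute flags.

```cpp
// Illustrative only: naive scaled-dot-product attention, Y = Softmax(alpha * Q * K^T) * V.
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

void attention_forward_naive(const std::vector<float>& Q, // [M, Kdim] row-major
                             const std::vector<float>& K, // [N, Kdim] row-major
                             const std::vector<float>& V, // [N, O] row-major
                             std::vector<float>& Y,       // [M, O] row-major
                             int M, int N, int Kdim, int O, float alpha)
{
    std::vector<float> p(N);
    for(int m = 0; m < M; ++m)
    {
        // s_n = alpha * dot(Q_m, K_n), reduced with the usual max shift for stability
        float row_max = -std::numeric_limits<float>::infinity();
        for(int n = 0; n < N; ++n)
        {
            float s = 0.f;
            for(int k = 0; k < Kdim; ++k)
                s += Q[m * Kdim + k] * K[n * Kdim + k];
            p[n]    = alpha * s;
            row_max = std::max(row_max, p[n]);
        }
        float sum = 0.f;
        for(int n = 0; n < N; ++n)
        {
            p[n] = std::exp(p[n] - row_max);
            sum += p[n];
        }
        // Y_m = softmax(s) * V
        for(int o = 0; o < O; ++o)
        {
            float acc = 0.f;
            for(int n = 0; n < N; ++n)
                acc += (p[n] / sum) * V[n * O + o];
            Y[m * O + o] = acc;
        }
    }
}
```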
@@ -351,8 +351,8 @@ int run(int argc, char* argv[])
        exit(0);
    }
    auto gemm = DeviceGemmInstance{};
    auto invoker = gemm.MakeInvoker();
    std::vector<DeviceGemmInstance::ProblemDesc> problem_descs;
    using DeviceMemPtr = std::unique_ptr<DeviceMem>;
@@ -393,7 +393,8 @@ int run(int argc, char* argv[])
    std::vector<DeviceMemPtr> vgrad_tensors_device;
    std::size_t group_count = 3;
    std::size_t flop = 0, num_byte = 0;
-   for(std::size_t i=0; i<group_count; i++){
+   for(std::size_t i = 0; i < group_count; i++)
+   {
        int M = 128 * (rand() % 4 + 1);
        int N = 128 * (rand() % 4 + 1);
        int K = 64;
...@@ -424,8 +425,8 @@ int run(int argc, char* argv[]) ...@@ -424,8 +425,8 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O] ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O] : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...) // pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...))) // = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...))) // = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^ // ^^^^^^^^^^^^^^^^^^^^^
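A small illustration of the identity spelled out in the comment above: if the forward pass saves lse = log(sum_j exp(s_j)) for each softmax row, the backward pass can rebuild every probability with a single exp per element instead of redoing the reduction. This is a hypothetical helper for illustration, not a library API.

```cpp
// Illustrative only: recover softmax probabilities from a row of logits and its saved log-sum-exp.
#include <cmath>
#include <vector>

std::vector<float> softmax_from_lse(const std::vector<float>& s, float lse)
{
    std::vector<float> p(s.size());
    for(std::size_t i = 0; i < s.size(); ++i)
        p[i] = std::exp(s[i] - lse); // Pi = exp(Si - LSE)
    return p;
}
```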
@@ -453,9 +454,9 @@ int run(int argc, char* argv[])
        flop += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
        // Q/K/V/Y, dQ/dK/dV/dY, LSE
        num_byte += (sizeof(DataType) * M * K + sizeof(DataType) * K * N +
                     sizeof(DataType) * N * O + sizeof(DataType) * M * O) *
                        size_t(2) * BatchCount +
                    sizeof(LSEDataType) * M * BatchCount;
        Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
        Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
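One way to read the flop count in the hunk above (my interpretation; the source only gives the formula): the backward pass performs three GEMMs whose inner dimension is K, recomputing S = alpha * Q * K^T plus dQ = dS * K and dK = dS^T * Q, and two GEMMs whose inner dimension is O, namely dV = P^T * dY and dP = dY * V^T, each counted as 2 * rows * cols * inner flops per batch:

$$\mathrm{flop} \;=\; \Bigl(\underbrace{3\,M N K}_{S,\;dQ,\;dK} \;+\; \underbrace{2\,M N O}_{dV,\;dP}\Bigr)\cdot 2 \cdot \mathrm{BatchCount}$$

The byte count pairs each of Q, K, V, Y with its gradient (the factor of 2) and adds the M log-sum-exp values per batch, matching the "Q/K/V/Y, dQ/dK/dV/dY, LSE" comment.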
@@ -463,7 +464,8 @@ int run(int argc, char* argv[])
        Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
        Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
        Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
-       if(i < 4){
+       if(i < 4)
+       {
            std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
            std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
            std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
@@ -539,19 +541,24 @@ int run(int argc, char* argv[])
        Tensor<DataType> y_g_m_o({BatchCount, M, O});
        Tensor<LSEDataType> lse_g_m({BatchCount, M});
-       q_gs_ms_ks.ForEach(
-           [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
-       k_gs_ns_ks.ForEach(
-           [&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
-       v_gs_os_ns.ForEach(
-           [&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
+       q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
+           q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+       });
+       k_gs_ns_ks.ForEach([&](auto& self, auto idx) {
+           k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+       });
+       v_gs_os_ns.ForEach([&](auto& self, auto idx) {
+           v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+       });
        lse_gs_ms.ForEach(
            [&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
-       run_attention_fwd_host(q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o, lse_g_m);
+       run_attention_fwd_host(
+           q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o, lse_g_m);
-       y_gs_ms_os.ForEach(
-           [&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); });
+       y_gs_ms_os.ForEach([&](auto& self, auto idx) {
+           self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
+       });
        lse_gs_ms.ForEach(
            [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
@@ -567,15 +574,24 @@ int run(int argc, char* argv[])
        y_tensors.push_back(y_gs_ms_os);
        lse_tensors.push_back(lse_gs_ms);
        ygrad_tensors.push_back(ygrad_gs_ms_os);
-       q_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize()));
-       k_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * k_gs_ns_ks.GetElementSpaceSize()));
-       v_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * v_gs_os_ns.GetElementSpaceSize()));
-       y_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * y_gs_ms_os.GetElementSpaceSize()));
-       lse_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize()));
-       qgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize()));
-       kgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * k_gs_ns_ks.GetElementSpaceSize()));
-       vgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * v_gs_os_ns.GetElementSpaceSize()));
-       ygrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * y_gs_ms_os.GetElementSpaceSize()));
+       q_tensors_device.emplace_back(
+           std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize()));
+       k_tensors_device.emplace_back(
+           std::make_unique<DeviceMem>(sizeof(DataType) * k_gs_ns_ks.GetElementSpaceSize()));
+       v_tensors_device.emplace_back(
+           std::make_unique<DeviceMem>(sizeof(DataType) * v_gs_os_ns.GetElementSpaceSize()));
+       y_tensors_device.emplace_back(
+           std::make_unique<DeviceMem>(sizeof(DataType) * y_gs_ms_os.GetElementSpaceSize()));
+       lse_tensors_device.emplace_back(
+           std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize()));
+       qgrad_tensors_device.emplace_back(
+           std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize()));
+       kgrad_tensors_device.emplace_back(
+           std::make_unique<DeviceMem>(sizeof(DataType) * k_gs_ns_ks.GetElementSpaceSize()));
+       vgrad_tensors_device.emplace_back(
+           std::make_unique<DeviceMem>(sizeof(DataType) * v_gs_os_ns.GetElementSpaceSize()));
+       ygrad_tensors_device.emplace_back(
+           std::make_unique<DeviceMem>(sizeof(DataType) * y_gs_ms_os.GetElementSpaceSize()));
        q_tensors_device.back()->ToDevice(q_gs_ms_ks.data());
        k_tensors_device.back()->ToDevice(k_gs_ns_ks.data());
        v_tensors_device.back()->ToDevice(v_gs_os_ns.data());
@@ -585,35 +601,34 @@ int run(int argc, char* argv[])
        kgrad_tensors_device.back()->SetZero();
        vgrad_tensors_device.back()->SetZero();
        ygrad_tensors_device.back()->ToDevice(ygrad_gs_ms_os.data());
        p_q.push_back(q_tensors_device.back()->GetDeviceBuffer());
        p_k.push_back(k_tensors_device.back()->GetDeviceBuffer());
        p_v.push_back(v_tensors_device.back()->GetDeviceBuffer());
        p_y.push_back(y_tensors_device.back()->GetDeviceBuffer());
        p_lse.push_back(lse_tensors_device.back()->GetDeviceBuffer());
        p_kgrad.push_back(kgrad_tensors_device.back()->GetDeviceBuffer());
        p_vgrad.push_back(vgrad_tensors_device.back()->GetDeviceBuffer());
        p_ygrad.push_back(ygrad_tensors_device.back()->GetDeviceBuffer());
        p_qgrad.push_back(qgrad_tensors_device.back()->GetDeviceBuffer());
    }
-   auto argument = gemm.MakeArgument(
-       p_q,
+   auto argument = gemm.MakeArgument(p_q,
                                      p_k,
                                      p_v,
                                      p_y,
                                      p_lse,
                                      p_ygrad,
                                      p_qgrad,
                                      p_kgrad,
                                      p_vgrad,
                                      {}, // std::array<void*, 1> p_acc0_biases;
                                      {}, // std::array<void*, 1> p_acc1_biases;
                                      problem_descs,
                                      QKVElementOp{},
                                      QKVElementOp{},
                                      Scale{alpha},
                                      QKVElementOp{},
                                      YElementOp{});
    DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
    gemm.SetWorkSpacePointer(&argument, problem_desc_workspace.GetDeviceBuffer());
@@ -644,20 +659,22 @@ int run(int argc, char* argv[])
    bool pass = true;
    if(do_verification)
    {
-       for(int i=0;i<group_count;i++){
+       for(int i = 0; i < group_count; i++)
+       {
            qgrad_tensors_device[i]->SetZero();
            kgrad_tensors_device[i]->SetZero();
            vgrad_tensors_device[i]->SetZero();
        }
        invoker.Run(argument, StreamConfig{nullptr, false});
-       for(std::size_t i=0; i<group_count; i++){
+       for(std::size_t i = 0; i < group_count; i++)
+       {
            int G0 = v_tensors[i].GetLengths()[0];
            int G1 = v_tensors[i].GetLengths()[1];
            int O = v_tensors[i].GetLengths()[2];
            int N = v_tensors[i].GetLengths()[3];
            int M = q_tensors[i].GetLengths()[2];
            int K = q_tensors[i].GetLengths()[3];
            int BatchCount = G0 * G1;
            Tensor<DataType> qgrad_g_m_k({BatchCount, M, K});
            Tensor<DataType> kgrad_g_n_k({BatchCount, N, K});
@@ -695,13 +712,19 @@ int run(int argc, char* argv[])
            ref_gemm_grad_invoker.Run(RefGemmGradArg{
                sgrad_g_n_m, q_g_m_ks[i], kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
-           Tensor<DataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(), q_tensors[i].GetStrides());
-           Tensor<DataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(), k_tensors[i].GetStrides());
-           Tensor<DataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(), v_tensors[i].GetStrides());
-           Tensor<DataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(), q_tensors[i].GetStrides());
-           Tensor<DataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(), k_tensors[i].GetStrides());
-           Tensor<DataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(), v_tensors[i].GetStrides());
+           Tensor<DataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(),
+                                                       q_tensors[i].GetStrides());
+           Tensor<DataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(),
+                                                       k_tensors[i].GetStrides());
+           Tensor<DataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(),
+                                                       v_tensors[i].GetStrides());
+           Tensor<DataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(),
+                                                         q_tensors[i].GetStrides());
+           Tensor<DataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(),
+                                                         k_tensors[i].GetStrides());
+           Tensor<DataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(),
+                                                         v_tensors[i].GetStrides());
            qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data());
            kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data());
...
@@ -36,16 +36,16 @@ template <typename GridwiseGemm,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2(
            const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
            const index_t group_count,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const AccElementwiseOperation acc_element_op,
            const B1ElementwiseOperation b1_element_op,
            const CElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -739,8 +739,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
            const auto vgrad_grid_desc_n_o = DeviceOp::MakeVGradGridDescriptor_N_O(
                problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
                problem_desc.b1_gs_gemm1ns_gemm1ks_strides);
-           const auto ygrad_grid_desc_o0_m_o1 =
-               DeviceOp::MakeYGradGridDescriptor_O0_M_O1(problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
+           const auto ygrad_grid_desc_o0_m_o1 = DeviceOp::MakeYGradGridDescriptor_O0_M_O1(
+               problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
            const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K(
                problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
@@ -889,15 +889,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
        float ave_time = 0;
        auto launch_kernel = [&](auto has_main_k_block_loop_) {
-           const auto kernel =
-               kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2<GridwiseGemm,
+           const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2<
+               GridwiseGemm,
                GroupKernelArg,
                AElementwiseOperation,
                BElementwiseOperation,
                AccElementwiseOperation,
                B1ElementwiseOperation,
                CElementwiseOperation,
                has_main_k_block_loop_>;
            return launch_and_time_kernel(
                stream_config,
@@ -963,7 +963,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
            const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
            const index_t c_gemm1n = kernel_arg.y_grid_desc_m_o_.GetLength(I1);
            const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
-           const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
+           const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) *
+                                     kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
            if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
            {
@@ -992,12 +993,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
            }
            // Check vector load/store requirement
            const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
                                             ? device_arg.a_mz_kz_strides_[1]
                                             : device_arg.a_mz_kz_strides_[0];
            const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2
                                             ? device_arg.b_nz_kz_strides_[1]
                                             : device_arg.b_nz_kz_strides_[0];
            const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
                                              ? device_arg.b1_nz_kz_strides_[1]
                                              : device_arg.b1_nz_kz_strides_[0];
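For context on the check these strides feed into (the remainder of the function is collapsed here, so this is a hedged reading rather than the repository's exact condition): `*BlockTransferSrcVectorDim` selects which dimension is accessed with vector loads, and such loads generally require that dimension to be contiguous in memory. A minimal sketch of that kind of predicate, with illustrative names only:

```cpp
// Illustrative only: a vectorized transfer of `scalar_per_vector` elements along the chosen
// dimension is typically valid only if that dimension has unit stride, or if no
// vectorization is requested at all.
bool stride_allows_vector_access(long stride_lowest, int scalar_per_vector)
{
    return stride_lowest == 1 || scalar_per_vector == 1;
}
```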
...