Commit 27482328 authored by fsx950223

format code

parent 24f96b50
@@ -152,7 +152,7 @@ using DeviceGemmInstance =
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#else
using DeviceGemmInstance =
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
@@ -313,7 +313,7 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
float K = 128;
float K = 128;
float alpha = 1.f / std::sqrt(K);
bool input_permute = false;
@@ -351,8 +351,8 @@ int run(int argc, char* argv[])
exit(0);
}
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
std::vector<DeviceGemmInstance::ProblemDesc> problem_descs;
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
@@ -393,7 +393,8 @@ int run(int argc, char* argv[])
std::vector<DeviceMemPtr> vgrad_tensors_device;
std::size_t group_count = 3;
std::size_t flop = 0, num_byte = 0;
for(std::size_t i=0; i<group_count; i++){
for(std::size_t i = 0; i < group_count; i++)
{
int M = 128 * (rand() % 4 + 1);
int N = 128 * (rand() % 4 + 1);
int K = 64;
@@ -424,8 +425,8 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
// pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^
@@ -453,9 +454,9 @@ int run(int argc, char* argv[])
flop += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
// Q/K/V/Y, dQ/dK/dV/dY, LSE
num_byte += (sizeof(DataType) * M * K + sizeof(DataType) * K * N +
sizeof(DataType) * N * O + sizeof(DataType) * M * O) *
size_t(2) * BatchCount +
sizeof(LSEDataType) * M * BatchCount;
sizeof(DataType) * N * O + sizeof(DataType) * M * O) *
size_t(2) * BatchCount +
sizeof(LSEDataType) * M * BatchCount;
Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
@@ -463,7 +464,8 @@ int run(int argc, char* argv[])
Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
if(i < 4){
if(i < 4)
{
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
@@ -539,19 +541,24 @@ int run(int argc, char* argv[])
Tensor<DataType> y_g_m_o({BatchCount, M, O});
Tensor<LSEDataType> lse_g_m({BatchCount, M});
q_gs_ms_ks.ForEach(
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
k_gs_ns_ks.ForEach(
[&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
v_gs_os_ns.ForEach(
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
k_gs_ns_ks.ForEach([&](auto& self, auto idx) {
k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
v_gs_os_ns.ForEach([&](auto& self, auto idx) {
v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
run_attention_fwd_host(q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o, lse_g_m);
y_gs_ms_os.ForEach(
[&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); });
run_attention_fwd_host(
q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o, lse_g_m);
y_gs_ms_os.ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
});
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
@@ -567,15 +574,24 @@ int run(int argc, char* argv[])
y_tensors.push_back(y_gs_ms_os);
lse_tensors.push_back(lse_gs_ms);
ygrad_tensors.push_back(ygrad_gs_ms_os);
q_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize()));
k_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * k_gs_ns_ks.GetElementSpaceSize()));
v_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * v_gs_os_ns.GetElementSpaceSize()));
y_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * y_gs_ms_os.GetElementSpaceSize()));
lse_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize()));
qgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize()));
kgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * k_gs_ns_ks.GetElementSpaceSize()));
vgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * v_gs_os_ns.GetElementSpaceSize()));
ygrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * y_gs_ms_os.GetElementSpaceSize()));
q_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize()));
k_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * k_gs_ns_ks.GetElementSpaceSize()));
v_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * v_gs_os_ns.GetElementSpaceSize()));
y_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * y_gs_ms_os.GetElementSpaceSize()));
lse_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize()));
qgrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize()));
kgrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * k_gs_ns_ks.GetElementSpaceSize()));
vgrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * v_gs_os_ns.GetElementSpaceSize()));
ygrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * y_gs_ms_os.GetElementSpaceSize()));
q_tensors_device.back()->ToDevice(q_gs_ms_ks.data());
k_tensors_device.back()->ToDevice(k_gs_ns_ks.data());
v_tensors_device.back()->ToDevice(v_gs_os_ns.data());
@@ -585,35 +601,34 @@ int run(int argc, char* argv[])
kgrad_tensors_device.back()->SetZero();
vgrad_tensors_device.back()->SetZero();
ygrad_tensors_device.back()->ToDevice(ygrad_gs_ms_os.data());
p_q.push_back(q_tensors_device.back()->GetDeviceBuffer());
p_k.push_back(k_tensors_device.back()->GetDeviceBuffer());
p_v.push_back(v_tensors_device.back()->GetDeviceBuffer());
p_y.push_back(y_tensors_device.back()->GetDeviceBuffer());
p_q.push_back(q_tensors_device.back()->GetDeviceBuffer());
p_k.push_back(k_tensors_device.back()->GetDeviceBuffer());
p_v.push_back(v_tensors_device.back()->GetDeviceBuffer());
p_y.push_back(y_tensors_device.back()->GetDeviceBuffer());
p_lse.push_back(lse_tensors_device.back()->GetDeviceBuffer());
p_kgrad.push_back(kgrad_tensors_device.back()->GetDeviceBuffer());
p_vgrad.push_back(vgrad_tensors_device.back()->GetDeviceBuffer());
p_ygrad.push_back(ygrad_tensors_device.back()->GetDeviceBuffer());
p_qgrad.push_back(qgrad_tensors_device.back()->GetDeviceBuffer());
}
auto argument = gemm.MakeArgument(
p_q,
p_k,
p_v,
p_y,
p_lse,
p_ygrad,
p_qgrad,
p_kgrad,
p_vgrad,
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
problem_descs,
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
QKVElementOp{},
YElementOp{});
auto argument = gemm.MakeArgument(p_q,
p_k,
p_v,
p_y,
p_lse,
p_ygrad,
p_qgrad,
p_kgrad,
p_vgrad,
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
problem_descs,
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
QKVElementOp{},
YElementOp{});
DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, problem_desc_workspace.GetDeviceBuffer());
@@ -644,20 +659,22 @@ int run(int argc, char* argv[])
bool pass = true;
if(do_verification)
{
for(int i=0;i<group_count;i++){
for(int i = 0; i < group_count; i++)
{
qgrad_tensors_device[i]->SetZero();
kgrad_tensors_device[i]->SetZero();
vgrad_tensors_device[i]->SetZero();
}
invoker.Run(argument, StreamConfig{nullptr, false});
for(std::size_t i=0; i<group_count; i++){
int G0 = v_tensors[i].GetLengths()[0];
int G1 = v_tensors[i].GetLengths()[1];
int O = v_tensors[i].GetLengths()[2];
int N = v_tensors[i].GetLengths()[3];
int M = q_tensors[i].GetLengths()[2];
int K = q_tensors[i].GetLengths()[3];
for(std::size_t i = 0; i < group_count; i++)
{
int G0 = v_tensors[i].GetLengths()[0];
int G1 = v_tensors[i].GetLengths()[1];
int O = v_tensors[i].GetLengths()[2];
int N = v_tensors[i].GetLengths()[3];
int M = q_tensors[i].GetLengths()[2];
int K = q_tensors[i].GetLengths()[3];
int BatchCount = G0 * G1;
Tensor<DataType> qgrad_g_m_k({BatchCount, M, K});
Tensor<DataType> kgrad_g_n_k({BatchCount, N, K});
@@ -695,13 +712,19 @@ int run(int argc, char* argv[])
ref_gemm_grad_invoker.Run(RefGemmGradArg{
sgrad_g_n_m, q_g_m_ks[i], kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
Tensor<DataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(), q_tensors[i].GetStrides());
Tensor<DataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(), k_tensors[i].GetStrides());
Tensor<DataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(), v_tensors[i].GetStrides());
Tensor<DataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(), q_tensors[i].GetStrides());
Tensor<DataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(), k_tensors[i].GetStrides());
Tensor<DataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(), v_tensors[i].GetStrides());
Tensor<DataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(),
q_tensors[i].GetStrides());
Tensor<DataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(),
k_tensors[i].GetStrides());
Tensor<DataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(),
v_tensors[i].GetStrides());
Tensor<DataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(),
q_tensors[i].GetStrides());
Tensor<DataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(),
k_tensors[i].GetStrides());
Tensor<DataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(),
v_tensors[i].GetStrides());
qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data());
kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data());
@@ -36,16 +36,16 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op)
kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -739,8 +739,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
const auto vgrad_grid_desc_n_o = DeviceOp::MakeVGradGridDescriptor_N_O(
problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
problem_desc.b1_gs_gemm1ns_gemm1ks_strides);
const auto ygrad_grid_desc_o0_m_o1 =
DeviceOp::MakeYGradGridDescriptor_O0_M_O1(problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
const auto ygrad_grid_desc_o0_m_o1 = DeviceOp::MakeYGradGridDescriptor_O0_M_O1(
problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K(
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
@@ -889,15 +889,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel =
kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2<GridwiseGemm,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_>;
const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2<
GridwiseGemm,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_>;
return launch_and_time_kernel(
stream_config,
@@ -963,7 +963,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
const index_t c_gemm1n = kernel_arg.y_grid_desc_m_o_.GetLength(I1);
const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) *
kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
{
@@ -992,12 +993,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
// Check vector load/store requirement
const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
? device_arg.a_mz_kz_strides_[1]
: device_arg.a_mz_kz_strides_[0];
const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2
? device_arg.b_nz_kz_strides_[1]
: device_arg.b_nz_kz_strides_[0];
const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
? device_arg.a_mz_kz_strides_[1]
: device_arg.a_mz_kz_strides_[0];
const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2
? device_arg.b_nz_kz_strides_[1]
: device_arg.b_nz_kz_strides_[0];
const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
? device_arg.b1_nz_kz_strides_[1]
: device_arg.b1_nz_kz_strides_[0];