Commit 27482328 authored by fsx950223

format code

parent 24f96b50
@@ -393,7 +393,8 @@ int run(int argc, char* argv[])
     std::vector<DeviceMemPtr> vgrad_tensors_device;
     std::size_t group_count = 3;
     std::size_t flop = 0, num_byte = 0;
-    for(std::size_t i=0; i<group_count; i++){
+    for(std::size_t i = 0; i < group_count; i++)
+    {
         int M = 128 * (rand() % 4 + 1);
         int N = 128 * (rand() % 4 + 1);
         int K = 64;
@@ -424,8 +425,8 @@ int run(int argc, char* argv[])
                 ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
                 : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
-        // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
-        // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
+        // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
+        // pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
         //    = exp(Si) / exp(log(sum(exp() + ...)))
         //    = exp(Si - log(sum(exp() + ...)))
         //               ^^^^^^^^^^^^^^^^^^^^^
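[Note] As an aside on the derivation in the comment above: storing the row statistic LSE = log(sum_j exp(S_j)) lets the backward pass rebuild P_i = exp(S_i - LSE) without reducing over the row again. A minimal standalone host sketch of that identity (illustrative names and values, independent of the example's tensors):

// Sketch of the LSE trick restated in the comment above.
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<float> s = {1.0f, 2.0f, 3.0f}; // one row of attention scores S

    // lse = log(sum_j exp(s_j)), computed with the usual max-shift for stability
    float s_max = s[0];
    for(float x : s)
        s_max = std::fmax(s_max, x);
    float sum = 0.0f;
    for(float x : s)
        sum += std::exp(x - s_max);
    const float lse = s_max + std::log(sum);

    // Backward can now rebuild each softmax value from the stored statistic alone.
    for(float x : s)
        std::printf("P = %f\n", std::exp(x - lse)); // == exp(x) / sum_j exp(s_j)
    return 0;
}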
@@ -463,7 +464,8 @@ int run(int argc, char* argv[])
         Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
         Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
         Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
-        if(i < 4){
+        if(i < 4)
+        {
             std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
             std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
             std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
@@ -539,19 +541,24 @@ int run(int argc, char* argv[])
         Tensor<DataType> y_g_m_o({BatchCount, M, O});
         Tensor<LSEDataType> lse_g_m({BatchCount, M});
-        q_gs_ms_ks.ForEach(
-            [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
-        k_gs_ns_ks.ForEach(
-            [&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
-        v_gs_os_ns.ForEach(
-            [&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
+        q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
+            q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        });
+        k_gs_ns_ks.ForEach([&](auto& self, auto idx) {
+            k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        });
+        v_gs_os_ns.ForEach([&](auto& self, auto idx) {
+            v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+        });
         lse_gs_ms.ForEach(
             [&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
-        run_attention_fwd_host(q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o, lse_g_m);
+        run_attention_fwd_host(
+            q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o, lse_g_m);
-        y_gs_ms_os.ForEach(
-            [&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); });
+        y_gs_ms_os.ForEach([&](auto& self, auto idx) {
+            self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
+        });
         lse_gs_ms.ForEach(
             [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
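[Note] The ForEach lambdas above fold the two batch dimensions [G0, G1] into the host reference's single batch dimension of size BatchCount = G0 * G1, using the flat index g = g0 * G1 + g1. A tiny standalone check of that mapping (illustrative sizes only):

// The same index arithmetic as idx[0] * G1 + idx[1] in the lambdas above,
// verified to be a bijection between (g0, g1) pairs and flat indices.
#include <cassert>

int main()
{
    const int G0 = 2, G1 = 3;
    for(int g0 = 0; g0 < G0; g0++)
        for(int g1 = 0; g1 < G1; g1++)
        {
            const int g = g0 * G1 + g1;           // flatten
            assert(g / G1 == g0 && g % G1 == g1); // unflatten recovers the pair
        }
    return 0;
}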
@@ -567,15 +574,24 @@ int run(int argc, char* argv[])
         y_tensors.push_back(y_gs_ms_os);
         lse_tensors.push_back(lse_gs_ms);
         ygrad_tensors.push_back(ygrad_gs_ms_os);
-        q_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize()));
-        k_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * k_gs_ns_ks.GetElementSpaceSize()));
-        v_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * v_gs_os_ns.GetElementSpaceSize()));
-        y_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * y_gs_ms_os.GetElementSpaceSize()));
-        lse_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize()));
-        qgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize()));
-        kgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * k_gs_ns_ks.GetElementSpaceSize()));
-        vgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * v_gs_os_ns.GetElementSpaceSize()));
-        ygrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(sizeof(DataType) * y_gs_ms_os.GetElementSpaceSize()));
+        q_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize()));
+        k_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(DataType) * k_gs_ns_ks.GetElementSpaceSize()));
+        v_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(DataType) * v_gs_os_ns.GetElementSpaceSize()));
+        y_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(DataType) * y_gs_ms_os.GetElementSpaceSize()));
+        lse_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize()));
+        qgrad_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize()));
+        kgrad_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(DataType) * k_gs_ns_ks.GetElementSpaceSize()));
+        vgrad_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(DataType) * v_gs_os_ns.GetElementSpaceSize()));
+        ygrad_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(DataType) * y_gs_ms_os.GetElementSpaceSize()));
         q_tensors_device.back()->ToDevice(q_gs_ms_ks.data());
         k_tensors_device.back()->ToDevice(k_gs_ns_ks.data());
         v_tensors_device.back()->ToDevice(v_gs_os_ns.data());
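[Note] These allocations are sized by GetElementSpaceSize() rather than by the product of the lengths, because a strided descriptor can address a span larger than its element count. A hedged sketch of that span computation, assuming positive strides; it mirrors the idea rather than CK's actual implementation:

// Why sizeof(DataType) * GetElementSpaceSize() is the right byte count:
// the furthest addressable element of a strided tensor sits at offset
// sum_i (len_i - 1) * stride_i, so the buffer must cover that offset + 1.
#include <cstddef>
#include <cstdio>
#include <vector>

std::size_t element_space_size(const std::vector<std::size_t>& lengths,
                               const std::vector<std::size_t>& strides)
{
    std::size_t space = 1;
    for(std::size_t d = 0; d < lengths.size(); d++)
        space += (lengths[d] - 1) * strides[d];
    return space;
}

int main()
{
    // Strides of the "Y layout [G0, M, G1, O]" case from an earlier hunk,
    // over logical dims (G0, G1, M, O), with illustrative sizes.
    const std::size_t G0 = 2, G1 = 3, M = 128, O = 64;
    std::printf("%zu\n",
                element_space_size({G0, G1, M, O}, {M * G1 * O, O, G1 * O, 1}));
    // Prints 49152 == G0 * G1 * M * O: dense but permuted, so span == count.
    return 0;
}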
@@ -595,8 +611,7 @@ int run(int argc, char* argv[])
         p_ygrad.push_back(ygrad_tensors_device.back()->GetDeviceBuffer());
         p_qgrad.push_back(qgrad_tensors_device.back()->GetDeviceBuffer());
     }
-    auto argument = gemm.MakeArgument(
-        p_q,
+    auto argument = gemm.MakeArgument(p_q,
                                       p_k,
                                       p_v,
                                       p_y,
@@ -644,13 +659,15 @@ int run(int argc, char* argv[])
     bool pass = true;
     if(do_verification)
     {
-        for(int i=0;i<group_count;i++){
+        for(int i = 0; i < group_count; i++)
+        {
             qgrad_tensors_device[i]->SetZero();
             kgrad_tensors_device[i]->SetZero();
             vgrad_tensors_device[i]->SetZero();
         }
         invoker.Run(argument, StreamConfig{nullptr, false});
-        for(std::size_t i=0; i<group_count; i++){
+        for(std::size_t i = 0; i < group_count; i++)
+        {
             int G0 = v_tensors[i].GetLengths()[0];
             int G1 = v_tensors[i].GetLengths()[1];
@@ -695,13 +712,19 @@ int run(int argc, char* argv[])
             ref_gemm_grad_invoker.Run(RefGemmGradArg{
                 sgrad_g_n_m, q_g_m_ks[i], kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
-            Tensor<DataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(), q_tensors[i].GetStrides());
-            Tensor<DataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(), k_tensors[i].GetStrides());
-            Tensor<DataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(), v_tensors[i].GetStrides());
-            Tensor<DataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(), q_tensors[i].GetStrides());
-            Tensor<DataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(), k_tensors[i].GetStrides());
-            Tensor<DataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(), v_tensors[i].GetStrides());
+            Tensor<DataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(),
+                                                        q_tensors[i].GetStrides());
+            Tensor<DataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(),
+                                                        k_tensors[i].GetStrides());
+            Tensor<DataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(),
+                                                        v_tensors[i].GetStrides());
+            Tensor<DataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(),
+                                                          q_tensors[i].GetStrides());
+            Tensor<DataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(),
+                                                          k_tensors[i].GetStrides());
+            Tensor<DataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(),
+                                                          v_tensors[i].GetStrides());
             qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data());
             kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data());
@@ -36,7 +36,7 @@ template <typename GridwiseGemm,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
     kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2(
         const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
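[Note] For context on the attribute re-indented above: __launch_bounds__(max_threads_per_block, min_blocks_per_cu) caps the block size a kernel may be launched with and asks the compiler to budget registers so at least that many blocks can stay resident per compute unit. A minimal HIP sketch with illustrative values (not CK's CK_MAX_THREAD_PER_BLOCK / CK_MIN_BLOCK_PER_CU):

// Illustrative HIP kernel carrying the same attribute the macro guards above.
// The values 256 and 2 are examples only.
#include <hip/hip_runtime.h>

__global__ void __launch_bounds__(256, 2) // <= 256 threads/block, >= 2 blocks/CU
scale_kernel(float* x, float alpha, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        x[i] *= alpha; // trivial body; the attribute is the point here
}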
@@ -739,8 +739,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
             const auto vgrad_grid_desc_n_o = DeviceOp::MakeVGradGridDescriptor_N_O(
                 problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
                 problem_desc.b1_gs_gemm1ns_gemm1ks_strides);
-            const auto ygrad_grid_desc_o0_m_o1 =
-                DeviceOp::MakeYGradGridDescriptor_O0_M_O1(problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
+            const auto ygrad_grid_desc_o0_m_o1 = DeviceOp::MakeYGradGridDescriptor_O0_M_O1(
+                problem_desc.c_gs_ms_gemm1ns_lengths, problem_desc.c_gs_ms_gemm1ns_strides);
             const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K(
                 problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
@@ -889,8 +889,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
         float ave_time = 0;
         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel =
-                kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2<GridwiseGemm,
+            const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2<
+                GridwiseGemm,
                 GroupKernelArg,
                 AElementwiseOperation,
                 BElementwiseOperation,
@@ -963,7 +963,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
             const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
             const index_t c_gemm1n = kernel_arg.y_grid_desc_m_o_.GetLength(I1);
             const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
-            const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
+            const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) *
+                                      kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
             if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
             {