Commit a39dd61f authored by danyao12's avatar danyao12
Browse files

refactor grouped bwd example and fix some bugs

parent 79f3caf8
......@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
#define RANGE_HDKO 2 // 0~2
#define RANGE_HDKO 1 // 0~2
#include <iostream>
#include <numeric>
......@@ -678,7 +678,6 @@ int run(int argc, char* argv[])
// = 0
}
// calculate y & log-sum-exp beforehand
Tensor<DataType> q_g_m_k({BatchCount, M, K});
Tensor<DataType> k_g_n_k({BatchCount, N, K});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
......
......@@ -211,7 +211,7 @@ template <index_t NumDimG,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
: public BaseOperator // TODO inherit atten bwd op once API stablizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
......@@ -223,7 +223,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1;
using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1;
struct ProblemDesc
{
std::vector<index_t> a_gs_ms_ks_lengths;
......@@ -448,7 +448,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
}
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
......@@ -534,7 +533,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
DataType, // TODO: distinguish A/B datatype
GemmDataType,
GemmAccDataType,
......@@ -711,7 +710,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
}
grid_size_ = 0;
for(std::size_t i = 0; i < group_count_; i++)
for(index_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const DataType*>(p_As[i]);
const auto p_b_grid = static_cast<const DataType*>(p_Bs[i]);
......@@ -895,7 +894,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// for(std::size_t i = 0; i < arg.group_count_; i++)
// {
// const auto K =
// arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
// arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
// arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
// const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K);
// all_has_main_k_block_loop &= y;
// some_has_main_k_block_loop |= y;
......@@ -976,7 +976,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
return false;
}
for(std::size_t i = 0; i < arg.group_count_; i++)
for(index_t i = 0; i < arg.group_count_; i++)
{
// TODO: Check if tensor specialization & strides mismatch
const auto& kernel_arg = arg.group_kernel_args_[i];
......@@ -986,7 +986,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
const index_t c_gemm1n = kernel_arg.y_grid_desc_m_o_.GetLength(I1);
const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) *
kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
{
......@@ -1160,7 +1161,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1"
str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -16,7 +16,7 @@
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -703,7 +703,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
grid_size_ = 0;
for(std::size_t i = 0; i < group_count_; i++)
for(index_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const DataType*>(p_As[i]);
const auto p_b_grid = static_cast<const DataType*>(p_Bs[i]);
......@@ -884,10 +884,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
bool all_has_main_k_block_loop = true;
bool some_has_main_k_block_loop = false;
for(std::size_t i = 0; i < arg.group_count_; i++)
for(index_t i = 0; i < arg.group_count_; i++)
{
const auto K =
arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
const auto K = arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K);
all_has_main_k_block_loop &= y;
some_has_main_k_block_loop |= y;
......@@ -968,7 +968,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
return false;
}
for(std::size_t i = 0; i < arg.group_count_; i++)
for(index_t i = 0; i < arg.group_count_; i++)
{
// TODO: Check if tensor specialization & strides mismatch
const auto& kernel_arg = arg.group_kernel_args_[i];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment